diff --git a/cuda/include/gkr.cuh b/cuda/include/gkr.cuh index 88a65d8..e4c9970 100644 --- a/cuda/include/gkr.cuh +++ b/cuda/include/gkr.cuh @@ -6,4 +6,98 @@ extern "C" void gen_eq_evals(qm31 v, qm31 *y, uint32_t y_size, qm31 *evals, uint32_t evals_size); +template +struct Fraction{ + T numerator; + qm31 denominator; + + __device__ Fraction() : numerator(T{}), denominator(qm31{}) {} + __device__ Fraction(T num, qm31 denom) : numerator(num), denominator(denom) {} +}; + +__device__ Fraction add_fraction(Fraction lhs, Fraction rhs); +__device__ Fraction add_fraction(Fraction lhs, Fraction rhs); + +template +struct Reciprocal{ + T x; + + __device__ Reciprocal() : x(T{}) {} + __device__ Reciprocal(T x) : x(x) {} +}; + +__device__ Fraction add_reciprocal(Reciprocal lhs, Reciprocal rhs); + +extern "C" { + void next_grand_product_layer( + qm31 *layer, + uint32_t layer_size, + qm31 *next_layer, + uint32_t next_layer_size + ); + + void next_logup_generic_layer( + qm31 *numerators, + qm31 *denominators, + uint32_t size, + qm31 *next_numerators, + qm31 *next_denominators, + uint32_t next_size + ); + + void next_logup_multiplicities_layer( + m31 *numerators, + qm31 *denominators, + uint32_t size, + qm31 *next_numerators, + qm31 *next_denominators, + uint32_t next_size + ); + + void next_logup_singles_layer( + qm31 *denominators, + uint32_t size, + qm31 *next_numerators, + qm31 *next_denominators, + uint32_t next_size + ); + + void eval_grand_product_sum( + qm31 *eq_evals, + qm31 *input_layer, + uint32_t n_terms, + qm31 *eval_at_0, + qm31 *eval_at_2 + ); + + void eval_logup_generic_sum( + qm31 *eq_evals, + qm31 *numerators, + qm31 *denominators, + uint32_t n_terms, + qm31 lambda, + qm31 *eval_at_0, + qm31 *eval_at_2 + ); + + void eval_logup_multiplicities_sum( + qm31 *eq_evals, + m31 *numerators, + qm31 *denominators, + uint32_t n_terms, + qm31 lambda, + qm31 *eval_at_0, + qm31 *eval_at_2 + ); + + void eval_logup_singles_sum( + qm31 *eq_evals, + qm31 *denominators, + uint32_t n_terms, + qm31 lambda, + qm31 *eval_at_0, + qm31 *eval_at_2 + ); +} + #endif // GKR_H diff --git a/cuda/src/gkr.cu b/cuda/src/gkr.cu index af52b76..88b98a1 100644 --- a/cuda/src/gkr.cu +++ b/cuda/src/gkr.cu @@ -1,3 +1,4 @@ +#include #include "utils.cuh" #include "gkr.cuh" @@ -36,3 +37,652 @@ void gen_eq_evals(qm31 v, qm31 *y, uint32_t y_size, qm31 *evals, uint32_t evals_ cudaFree(factors_device); } + +__device__ Fraction add_fraction(Fraction lhs, Fraction rhs) { + qm31 numerator = add(mul(lhs.numerator, rhs.denominator), mul(rhs.numerator, lhs.denominator)); + qm31 denominator = mul(lhs.denominator, rhs.denominator); + return Fraction {numerator, denominator}; +} + +__device__ Fraction add_fraction(Fraction lhs, Fraction rhs) { + qm31 numerator = add(mul(lhs.numerator, rhs.denominator), mul(rhs.numerator, lhs.denominator)); + qm31 denominator = mul(lhs.denominator, rhs.denominator); + return Fraction(numerator, denominator); +} + +__device__ Fraction add_reciprocal(Reciprocal lhs, Reciprocal rhs) { + // `1/a + 1/b = (a + b)/(a * b)` + return Fraction(add(lhs.x, rhs.x), mul(lhs.x, rhs.x)); +} + +// Function performs a tree-style reduction for efficient value aggregation +// Size of output is the number of blocks +__global__ void reduction_kernel(qm31 *input, uint32_t input_size, qm31 *output) { + unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ qm31 shared_eval[1024]; + if (tid < input_size) { + shared_eval[threadIdx.x] = input[tid]; + } + else { + shared_eval[threadIdx.x] = {0, 0, 0, 0}; + } + __syncthreads(); + + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + shared_eval[threadIdx.x] = add(shared_eval[threadIdx.x], shared_eval[threadIdx.x + s]); + } + __syncthreads(); + } + + if (threadIdx.x == 0) output[blockIdx.x] = shared_eval[0]; +} + +__global__ void next_grand_product_layer_kernel(qm31 *layer, uint32_t layer_size, qm31 *next_layer, uint32_t next_layer_size) { + unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < next_layer_size) { + next_layer[tid] = mul(layer[tid * 2], layer[tid * 2 + 1]); + } +} + +void next_grand_product_layer(qm31 *layer, uint32_t layer_size, qm31 *next_layer, uint32_t next_layer_size) { + const unsigned int BLOCK_SIZE = 1024; + const unsigned int NUM_BLOCKS = (next_layer_size + BLOCK_SIZE - 1) / BLOCK_SIZE; + next_grand_product_layer_kernel<<>>(layer, layer_size, next_layer, next_layer_size); + cudaDeviceSynchronize(); +} + +__global__ void next_logup_generic_layer_kernel( + qm31 *numerators, + qm31 *denominators, + uint32_t size, + qm31 *next_numerators, + qm31 *next_denominators, + uint32_t next_size +) { + unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < next_size) { + Fraction a = Fraction(numerators[tid * 2], denominators[tid * 2]); + Fraction b = Fraction(numerators[tid * 2 + 1], denominators[tid * 2 + 1]); + + Fraction res = add_fraction(a, b); + next_numerators[tid] = res.numerator; + next_denominators[tid] = res.denominator; + } +} + +void next_logup_generic_layer( + qm31 *numerators, + qm31 *denominators, + uint32_t size, + qm31 *next_numerators, + qm31 *next_denominators, + uint32_t next_size +) { + const unsigned int BLOCK_SIZE = 1024; + const unsigned int NUM_BLOCKS = (next_size + BLOCK_SIZE - 1) / BLOCK_SIZE; + + next_logup_generic_layer_kernel<<>>(numerators, denominators, size, next_numerators, next_denominators, next_size); + cudaDeviceSynchronize(); +} + +__global__ void next_logup_multiplicities_layer_kernel( + m31 *numerators, + qm31 *denominators, + uint32_t size, + qm31 *next_numerators, + qm31 *next_denominators, + uint32_t next_size +) { + unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < next_size) { + Fraction a = Fraction(numerators[tid * 2], denominators[tid * 2]); + Fraction b = Fraction(numerators[tid * 2 + 1], denominators[tid * 2 + 1]); + + Fraction res = add_fraction(a, b); + next_numerators[tid] = res.numerator; + next_denominators[tid] = res.denominator; + } +} + +void next_logup_multiplicities_layer( + m31 *numerators, + qm31 *denominators, + uint32_t size, + qm31 *next_numerators, + qm31 *next_denominators, + uint32_t next_size +) { + const unsigned int BLOCK_SIZE = 1024; + const unsigned int NUM_BLOCKS = (next_size + BLOCK_SIZE - 1) / BLOCK_SIZE; + + next_logup_multiplicities_layer_kernel<<>>(numerators, denominators, size, next_numerators, next_denominators, next_size); + cudaDeviceSynchronize(); +} + +__global__ void next_logup_singles_layer_kernel( + qm31 *denominators, + uint32_t size, + qm31 *next_numerators, + qm31 *next_denominators, + uint32_t next_size +) { + unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < next_size) { + Reciprocal even = Reciprocal(denominators[tid * 2]); + Reciprocal odd = Reciprocal(denominators[tid * 2 + 1]); + Fraction res = add_reciprocal(even, odd); + + next_numerators[tid] = res.numerator; + next_denominators[tid] = res.denominator; + } +} + +void next_logup_singles_layer( + qm31 *denominators, + uint32_t size, + qm31 *next_numerators, + qm31 *next_denominators, + uint32_t next_size +) { + const unsigned int BLOCK_SIZE = 1024; + const unsigned int NUM_BLOCKS = (next_size + BLOCK_SIZE - 1) / BLOCK_SIZE; + + next_logup_singles_layer_kernel<<>>(denominators, size, next_numerators, next_denominators, next_size); + cudaDeviceSynchronize(); +} + +__global__ void eval_grand_product_sum_kernel( + qm31 *eq_evals, + qm31 *input_layer, + uint32_t n_terms, + qm31 *eval_at_0, + qm31 *eval_at_2 +) { + unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x; + + // todo: specify size of shared memory + __shared__ qm31 shared_eval_0[1024]; + __shared__ qm31 shared_eval_2[1024]; + + shared_eval_0[threadIdx.x] = {0, 0, 0, 0}; + shared_eval_2[threadIdx.x] = {0, 0, 0, 0}; + __syncthreads(); + + if (tid < n_terms) { + qm31 inp_at_r0i0 = input_layer[tid * 2]; + qm31 inp_at_r0i1 = input_layer[tid * 2 + 1]; + qm31 inp_at_r1i0 = input_layer[(n_terms + tid) * 2]; + qm31 inp_at_r1i1 = input_layer[(n_terms + tid) * 2 + 1]; + + qm31 inp_at_r2i0 = sub(add(inp_at_r1i0, inp_at_r1i0), inp_at_r0i0); // inp_at_r2i0 + qm31 inp_at_r2i1 = sub(add(inp_at_r1i1, inp_at_r1i1), inp_at_r0i1); // inp_at_r2i1 + + qm31 prod_at_r2i = mul(inp_at_r2i0, inp_at_r2i1); // prod_at_r2i + qm31 prod_at_r0i = mul(inp_at_r0i0, inp_at_r0i1); // prod_at_r0i + + shared_eval_0[threadIdx.x] = mul(eq_evals[tid], prod_at_r0i); // eq_eval_at_0i * prod_at_r0i + shared_eval_2[threadIdx.x] = mul(eq_evals[tid], prod_at_r2i); // eq_eval_at_0i * prod_at_r2i + } + __syncthreads(); + + // Perform intra-block reduction in shared memory + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + shared_eval_0[threadIdx.x] = add(shared_eval_0[threadIdx.x], shared_eval_0[threadIdx.x + s]); + shared_eval_2[threadIdx.x] = add(shared_eval_2[threadIdx.x], shared_eval_2[threadIdx.x + s]); + } + __syncthreads(); + } + + // Set global memory for each block id, grab the reduced shared thread id w.r.t. each block + if (threadIdx.x == 0) { + eval_at_0[blockIdx.x] = shared_eval_0[threadIdx.x]; + eval_at_2[blockIdx.x] = shared_eval_2[threadIdx.x]; + } +} + +void eval_grand_product_sum( + qm31 *eq_evals, + qm31 *input_layer, + uint32_t n_terms, + qm31 *eval_at_0, + qm31 *eval_at_2 +) { + const unsigned int BLOCK_SIZE = 1024; + const unsigned int NUM_BLOCKS = (n_terms + BLOCK_SIZE - 1) / BLOCK_SIZE; + + // Arrays for intra-block reduction + qm31 *eval_at_0_temp_d; + qm31 *eval_at_2_temp_d; + + cudaMalloc((void **)&eval_at_0_temp_d, sizeof(qm31) * NUM_BLOCKS); + cudaMalloc((void **)&eval_at_2_temp_d, sizeof(qm31) * NUM_BLOCKS); + + size_t shared_mem_size = 2 * BLOCK_SIZE * sizeof(qm31); + eval_grand_product_sum_kernel<<>>( + eq_evals, + input_layer, + n_terms, + eval_at_0_temp_d, + eval_at_2_temp_d + ); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + fprintf(stderr, "eval_grand_product_sum_kernel launch error: %s\n", cudaGetErrorString(err)); + } + + // Synchronize to catch any runtime errors + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + fprintf(stderr, "eval_grand_product_sum_kernel execution error: %s\n", cudaGetErrorString(err)); + } + + // Post intra-block reduction + const int BLOCK_REDUCTION_SIZE = 1024; + unsigned int input_size = NUM_BLOCKS; + while (input_size > 1) { + unsigned int num_blocks = (input_size + BLOCK_REDUCTION_SIZE - 1) / BLOCK_REDUCTION_SIZE; + + // eval 0 + qm31 *reduction_output_0; + cudaMalloc((void **)&reduction_output_0, sizeof(qm31) * num_blocks); + reduction_kernel<<>>(eval_at_0_temp_d, input_size, reduction_output_0); + cudaFree(eval_at_0_temp_d); + eval_at_0_temp_d = reduction_output_0; + + // eval 2 + qm31 *reduction_output_2; + cudaMalloc((void **)&reduction_output_2, sizeof(qm31) * num_blocks); + reduction_kernel<<>>(eval_at_2_temp_d, input_size, reduction_output_2); + cudaFree(eval_at_2_temp_d); + eval_at_2_temp_d = reduction_output_2; + + input_size = num_blocks; + } + + cudaMemcpy(eval_at_0, eval_at_0_temp_d, sizeof(qm31), cudaMemcpyDeviceToDevice); + cudaMemcpy(eval_at_2, eval_at_2_temp_d, sizeof(qm31), cudaMemcpyDeviceToDevice); + + cudaFree(eval_at_0_temp_d); + cudaFree(eval_at_2_temp_d); +} + +__global__ void eval_logup_generic_sum_kernel( + qm31 *eq_evals, + qm31 *numerators, + qm31 *denominators, + uint32_t n_terms, + qm31 lambda, + qm31 *eval_at_0, + qm31 *eval_at_2 +) { + unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x; + + // todo: specify size of shared memory + __shared__ qm31 shared_eval_0[512]; + __shared__ qm31 shared_eval_2[512]; + + shared_eval_0[threadIdx.x] = {0, 0, 0, 0}; + shared_eval_2[threadIdx.x] = {0, 0, 0, 0}; + __syncthreads(); + + if (tid < n_terms) { + qm31 inp_numer_at_r0i0 = numerators[tid * 2]; + qm31 inp_denom_at_r0i0 = denominators[tid * 2]; + qm31 inp_numer_at_r0i1 = numerators[tid * 2 + 1]; + qm31 inp_denom_at_r0i1 = denominators[tid * 2 + 1]; + qm31 inp_numer_at_r1i0 = numerators[(n_terms + tid) * 2]; + qm31 inp_denom_at_r1i0 = denominators[(n_terms + tid) * 2]; + qm31 inp_numer_at_r1i1 = numerators[(n_terms + tid) * 2 + 1]; + qm31 inp_denom_at_r1i1 = denominators[(n_terms + tid) * 2 + 1]; + + qm31 inp_numer_at_r2i0 = sub(add(inp_numer_at_r1i0, inp_numer_at_r1i0), inp_numer_at_r0i0); + qm31 inp_denom_at_r2i0 = sub(add(inp_denom_at_r1i0, inp_denom_at_r1i0), inp_denom_at_r0i0); + qm31 inp_numer_at_r2i1 = sub(add(inp_numer_at_r1i1, inp_numer_at_r1i1), inp_numer_at_r0i1); + qm31 inp_denom_at_r2i1 = sub(add(inp_denom_at_r1i1, inp_denom_at_r1i1), inp_denom_at_r0i1); + + Fraction fraction_eval_0 = add_fraction(Fraction(inp_numer_at_r0i0, inp_denom_at_r0i0), Fraction(inp_numer_at_r0i1, inp_denom_at_r0i1)); + Fraction fraction_eval_2 = add_fraction(Fraction(inp_numer_at_r2i0, inp_denom_at_r2i0), Fraction(inp_numer_at_r2i1, inp_denom_at_r2i1)); + + shared_eval_0[threadIdx.x] = mul(eq_evals[tid], add(fraction_eval_0.numerator, mul(lambda, fraction_eval_0.denominator))); + shared_eval_2[threadIdx.x] = mul(eq_evals[tid], add(fraction_eval_2.numerator, mul(lambda, fraction_eval_2.denominator))); + } + __syncthreads(); + + // Perform intra-block reduction in shared memory + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + shared_eval_0[threadIdx.x] = add(shared_eval_0[threadIdx.x], shared_eval_0[threadIdx.x + s]); + shared_eval_2[threadIdx.x] = add(shared_eval_2[threadIdx.x], shared_eval_2[threadIdx.x + s]); + } + __syncthreads(); + } + + // Set global memory for each block id, grab the reduced shared thread id w.r.t. each block + if (threadIdx.x == 0) { + eval_at_0[blockIdx.x] = shared_eval_0[threadIdx.x]; + eval_at_2[blockIdx.x] = shared_eval_2[threadIdx.x]; + } + +} + +void eval_logup_generic_sum( + qm31 *eq_evals, + qm31 *numerators, + qm31 *denominators, + uint32_t n_terms, + qm31 lambda, + qm31 *eval_at_0, + qm31 *eval_at_2 +) { + const unsigned int BLOCK_SIZE = 512; + const unsigned int NUM_BLOCKS = (n_terms + BLOCK_SIZE - 1) / BLOCK_SIZE; + + // Arrays for intra-block reduction + qm31 *eval_at_0_temp_d; + qm31 *eval_at_2_temp_d; + + cudaMalloc((void **)&eval_at_0_temp_d, sizeof(qm31) * NUM_BLOCKS); + cudaMalloc((void **)&eval_at_2_temp_d, sizeof(qm31) * NUM_BLOCKS); + + eval_logup_generic_sum_kernel<<>>( + eq_evals, + numerators, + denominators, + n_terms, + lambda, + eval_at_0_temp_d, + eval_at_2_temp_d + ); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + fprintf(stderr, "eval_logup_generic_sum_kernel launch error: %s\n", cudaGetErrorString(err)); + } + + // Synchronize to catch any runtime errors + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + fprintf(stderr, "eval_logup_generic_sum_kernel execution error: %s\n", cudaGetErrorString(err)); + } + + // Post intra-block reduction + const int BLOCK_REDUCTION_SIZE = 1024; + unsigned int input_size = NUM_BLOCKS; + while (input_size > 1) { + unsigned int num_blocks = (input_size + BLOCK_REDUCTION_SIZE - 1) / BLOCK_REDUCTION_SIZE; + + // eval 0 + qm31 *reduction_output_0; + cudaMalloc((void **)&reduction_output_0, sizeof(qm31) * num_blocks); + reduction_kernel<<>>(eval_at_0_temp_d, input_size, reduction_output_0); + cudaFree(eval_at_0_temp_d); + eval_at_0_temp_d = reduction_output_0; + + // eval 2 + qm31 *reduction_output_2; + cudaMalloc((void **)&reduction_output_2, sizeof(qm31) * num_blocks); + reduction_kernel<<>>(eval_at_2_temp_d, input_size, reduction_output_2); + cudaFree(eval_at_2_temp_d); + eval_at_2_temp_d = reduction_output_2; + + input_size = num_blocks; + } + + cudaMemcpy(eval_at_0, eval_at_0_temp_d, sizeof(qm31), cudaMemcpyDeviceToDevice); + cudaMemcpy(eval_at_2, eval_at_2_temp_d, sizeof(qm31), cudaMemcpyDeviceToDevice); + + cudaFree(eval_at_0_temp_d); + cudaFree(eval_at_2_temp_d); +} + +__global__ void eval_logup_multiplicities_sum_kernel( + qm31 *eq_evals, + m31 *numerators, + qm31 *denominators, + uint32_t n_terms, + qm31 lambda, + qm31 *eval_at_0, + qm31 *eval_at_2 +) { + unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x; + + // todo: specify size of shared memory + __shared__ qm31 shared_eval_0[512]; + __shared__ qm31 shared_eval_2[512]; + + shared_eval_0[threadIdx.x] = {0, 0, 0, 0}; + shared_eval_2[threadIdx.x] = {0, 0, 0, 0}; + __syncthreads(); + + if (tid < n_terms) { + m31 inp_numer_at_r0i0 = numerators[tid * 2]; + qm31 inp_denom_at_r0i0 = denominators[tid * 2]; + m31 inp_numer_at_r0i1 = numerators[tid * 2 + 1]; + qm31 inp_denom_at_r0i1 = denominators[tid * 2 + 1]; + m31 inp_numer_at_r1i0 = numerators[(n_terms + tid) * 2]; + qm31 inp_denom_at_r1i0 = denominators[(n_terms + tid) * 2]; + m31 inp_numer_at_r1i1 = numerators[(n_terms + tid) * 2 + 1]; + qm31 inp_denom_at_r1i1 = denominators[(n_terms + tid) * 2 + 1]; + + m31 inp_numer_at_r2i0 = sub(add(inp_numer_at_r1i0, inp_numer_at_r1i0), inp_numer_at_r0i0); + qm31 inp_denom_at_r2i0 = sub(add(inp_denom_at_r1i0, inp_denom_at_r1i0), inp_denom_at_r0i0); + m31 inp_numer_at_r2i1 = sub(add(inp_numer_at_r1i1, inp_numer_at_r1i1), inp_numer_at_r0i1); + qm31 inp_denom_at_r2i1 = sub(add(inp_denom_at_r1i1, inp_denom_at_r1i1), inp_denom_at_r0i1); + + Fraction fraction_eval_0 = add_fraction(Fraction(inp_numer_at_r0i0, inp_denom_at_r0i0), Fraction(inp_numer_at_r0i1, inp_denom_at_r0i1)); + Fraction fraction_eval_2 = add_fraction(Fraction(inp_numer_at_r2i0, inp_denom_at_r2i0), Fraction(inp_numer_at_r2i1, inp_denom_at_r2i1)); + + shared_eval_0[threadIdx.x] = mul(eq_evals[tid], add(fraction_eval_0.numerator, mul(lambda, fraction_eval_0.denominator))); + shared_eval_2[threadIdx.x] = mul(eq_evals[tid], add(fraction_eval_2.numerator, mul(lambda, fraction_eval_2.denominator))); + } + __syncthreads(); + + // Perform intra-block reduction in shared memory + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + shared_eval_0[threadIdx.x] = add(shared_eval_0[threadIdx.x], shared_eval_0[threadIdx.x + s]); + shared_eval_2[threadIdx.x] = add(shared_eval_2[threadIdx.x], shared_eval_2[threadIdx.x + s]); + } + __syncthreads(); + } + + // Set global memory for each block id, grab the reduced shared thread id w.r.t. each block + if (threadIdx.x == 0) { + eval_at_0[blockIdx.x] = shared_eval_0[threadIdx.x]; + eval_at_2[blockIdx.x] = shared_eval_2[threadIdx.x]; + } + +} + +void eval_logup_multiplicities_sum( + qm31 *eq_evals, + m31 *numerators, + qm31 *denominators, + uint32_t n_terms, + qm31 lambda, + qm31 *eval_at_0, + qm31 *eval_at_2 +) { + const unsigned int BLOCK_SIZE = 512; + const unsigned int NUM_BLOCKS = (n_terms + BLOCK_SIZE - 1) / BLOCK_SIZE; + + // Arrays for intra-block reduction + qm31 *eval_at_0_temp_d; + qm31 *eval_at_2_temp_d; + + cudaMalloc((void **)&eval_at_0_temp_d, sizeof(qm31) * NUM_BLOCKS); + cudaMalloc((void **)&eval_at_2_temp_d, sizeof(qm31) * NUM_BLOCKS); + + eval_logup_multiplicities_sum_kernel<<>>( + eq_evals, + numerators, + denominators, + n_terms, + lambda, + eval_at_0_temp_d, + eval_at_2_temp_d + ); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + fprintf(stderr, "eval_logup_multiplicities_sum_kernel launch error: %s\n", cudaGetErrorString(err)); + } + + // Synchronize to catch any runtime errors + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + fprintf(stderr, "eval_logup_multiplicities_sum_kernel execution error: %s\n", cudaGetErrorString(err)); + } + + // Post intra-block reduction + const int BLOCK_REDUCTION_SIZE = 1024; + unsigned int input_size = NUM_BLOCKS; + while (input_size > 1) { + unsigned int num_blocks = (input_size + BLOCK_REDUCTION_SIZE - 1) / BLOCK_REDUCTION_SIZE; + + // eval 0 + qm31 *reduction_output_0; + cudaMalloc((void **)&reduction_output_0, sizeof(qm31) * num_blocks); + reduction_kernel<<>>(eval_at_0_temp_d, input_size, reduction_output_0); + cudaFree(eval_at_0_temp_d); + eval_at_0_temp_d = reduction_output_0; + + // eval 2 + qm31 *reduction_output_2; + cudaMalloc((void **)&reduction_output_2, sizeof(qm31) * num_blocks); + reduction_kernel<<>>(eval_at_2_temp_d, input_size, reduction_output_2); + cudaFree(eval_at_2_temp_d); + eval_at_2_temp_d = reduction_output_2; + + input_size = num_blocks; + } + + cudaMemcpy(eval_at_0, eval_at_0_temp_d, sizeof(qm31), cudaMemcpyDeviceToDevice); + cudaMemcpy(eval_at_2, eval_at_2_temp_d, sizeof(qm31), cudaMemcpyDeviceToDevice); + cudaFree(eval_at_0_temp_d); + cudaFree(eval_at_2_temp_d); +} + +__global__ void eval_logup_singles_sum_kernel( + qm31 *eq_evals, + qm31 *denominators, + uint32_t n_terms, + qm31 lambda, + qm31 *eval_at_0, + qm31 *eval_at_2 +) { + unsigned int tid = blockIdx.x * blockDim.x + threadIdx.x; + + // todo: specify size of shared memory + __shared__ qm31 shared_eval_0[512]; + __shared__ qm31 shared_eval_2[512]; + + shared_eval_0[threadIdx.x] = {0, 0, 0, 0}; + shared_eval_2[threadIdx.x] = {0, 0, 0, 0}; + __syncthreads(); + + if (tid < n_terms) { + qm31 inp_denom_at_r0i0 = denominators[tid * 2]; + qm31 inp_denom_at_r0i1 = denominators[tid * 2 + 1]; + qm31 inp_denom_at_r1i0 = denominators[(n_terms + tid) * 2]; + qm31 inp_denom_at_r1i1 = denominators[(n_terms + tid) * 2 + 1]; + + qm31 inp_denom_at_r2i0 = sub(add(inp_denom_at_r1i0, inp_denom_at_r1i0), inp_denom_at_r0i0); + qm31 inp_denom_at_r2i1 = sub(add(inp_denom_at_r1i1, inp_denom_at_r1i1), inp_denom_at_r0i1); + + Fraction fraction_eval_0 = add_reciprocal(Reciprocal(inp_denom_at_r0i0), Reciprocal(inp_denom_at_r0i1)); + Fraction fraction_eval_2 = add_reciprocal(Reciprocal(inp_denom_at_r2i0), Reciprocal(inp_denom_at_r2i1)); + + shared_eval_0[threadIdx.x] = mul(eq_evals[tid], add(fraction_eval_0.numerator, mul(lambda, fraction_eval_0.denominator))); + shared_eval_2[threadIdx.x] = mul(eq_evals[tid], add(fraction_eval_2.numerator, mul(lambda, fraction_eval_2.denominator))); + } + __syncthreads(); + + // Perform intra-block reduction in shared memory + for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + shared_eval_0[threadIdx.x] = add(shared_eval_0[threadIdx.x], shared_eval_0[threadIdx.x + s]); + shared_eval_2[threadIdx.x] = add(shared_eval_2[threadIdx.x], shared_eval_2[threadIdx.x + s]); + } + __syncthreads(); + } + + // Set global memory for each block id, grab the reduced shared thread id w.r.t. each block + if (threadIdx.x == 0) { + eval_at_0[blockIdx.x] = shared_eval_0[threadIdx.x]; + eval_at_2[blockIdx.x] = shared_eval_2[threadIdx.x]; + } + +} + +void eval_logup_singles_sum( + qm31 *eq_evals, + qm31 *denominators, + uint32_t n_terms, + qm31 lambda, + qm31 *eval_at_0, + qm31 *eval_at_2 +) { + const unsigned int BLOCK_SIZE = 512; + const unsigned int NUM_BLOCKS = (n_terms + BLOCK_SIZE - 1) / BLOCK_SIZE; + + // Arrays for intra-block reduction + qm31 *eval_at_0_temp_d; + qm31 *eval_at_2_temp_d; + + cudaMalloc((void **)&eval_at_0_temp_d, sizeof(qm31) * NUM_BLOCKS); + cudaMalloc((void **)&eval_at_2_temp_d, sizeof(qm31) * NUM_BLOCKS); + + eval_logup_singles_sum_kernel<<>>( + eq_evals, + denominators, + n_terms, + lambda, + eval_at_0_temp_d, + eval_at_2_temp_d + ); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + fprintf(stderr, "eval_logup_singles_sum_kernel launch error: %s\n", cudaGetErrorString(err)); + } + + // Synchronize to catch any runtime errors + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + fprintf(stderr, "eval_logup_singles_sum_kernel execution error: %s\n", cudaGetErrorString(err)); + } + + // Post intra-block reduction + const int BLOCK_REDUCTION_SIZE = 1024; + unsigned int input_size = NUM_BLOCKS; + while (input_size > 1) { + unsigned int num_blocks = (input_size + BLOCK_REDUCTION_SIZE - 1) / BLOCK_REDUCTION_SIZE; + + // eval 0 + qm31 *reduction_output_0; + cudaMalloc((void **)&reduction_output_0, sizeof(qm31) * num_blocks); + reduction_kernel<<>>(eval_at_0_temp_d, input_size, reduction_output_0); + cudaFree(eval_at_0_temp_d); + eval_at_0_temp_d = reduction_output_0; + + // eval 2 + qm31 *reduction_output_2; + cudaMalloc((void **)&reduction_output_2, sizeof(qm31) * num_blocks); + reduction_kernel<<>>(eval_at_2_temp_d, input_size, reduction_output_2); + cudaFree(eval_at_2_temp_d); + eval_at_2_temp_d = reduction_output_2; + + input_size = num_blocks; + } + + cudaMemcpy(eval_at_0, eval_at_0_temp_d, sizeof(qm31), cudaMemcpyDeviceToDevice); + cudaMemcpy(eval_at_2, eval_at_2_temp_d, sizeof(qm31), cudaMemcpyDeviceToDevice); + cudaFree(eval_at_0_temp_d); + cudaFree(eval_at_2_temp_d); +} + diff --git a/stwo_gpu_backend/src/column.rs b/stwo_gpu_backend/src/column.rs index 2d3bd72..e071112 100644 --- a/stwo_gpu_backend/src/column.rs +++ b/stwo_gpu_backend/src/column.rs @@ -4,7 +4,7 @@ use stwo_prover::core::{ fields::{m31::BaseField, qm31::SecureField}, }; -use crate::cuda::{bindings, BaseFieldVec}; +use crate::cuda::{bindings, BaseFieldVec, SecureFieldVec}; use crate::{backend::CudaBackend, cuda}; impl ColumnOps for CudaBackend { @@ -83,8 +83,9 @@ impl Column for cuda::SecureFieldVec { self.size } - fn at(&self, _index: usize) -> SecureField { - todo!() + fn at(&self, index: usize) -> SecureField { + // TODO: call binding to return value directly + self.to_cpu()[index] } fn set(&mut self, _index: usize, _value: SecureField) { @@ -100,8 +101,9 @@ impl Column for cuda::SecureFieldVec { } impl FromIterator for cuda::SecureFieldVec { - fn from_iter>(_iter: T) -> Self { - todo!() + fn from_iter>(iter: T) -> Self { + let secure_field_vec: Vec = iter.into_iter().collect(); + SecureFieldVec::from_vec(secure_field_vec) } } diff --git a/stwo_gpu_backend/src/cuda/bindings.rs b/stwo_gpu_backend/src/cuda/bindings.rs index 9d5b096..7723875 100644 --- a/stwo_gpu_backend/src/cuda/bindings.rs +++ b/stwo_gpu_backend/src/cuda/bindings.rs @@ -5,6 +5,13 @@ use stwo_prover::core::{ fields::{m31::BaseField, qm31::SecureField}, }; +use super::{BaseFieldVec, SecureFieldVec}; + +#[repr(C)] +pub struct CudaBaseField { + a: BaseField, +} + #[repr(C)] pub struct CudaSecureField { a: BaseField, @@ -268,4 +275,75 @@ extern "C" { extended_domain_size: u32, number_of_columns: u32, ); + + pub fn next_grand_product_layer( + layer: *const CudaSecureField, + layer_size: usize, + next_layer: *const CudaSecureField, + next_layer_size: usize, + ); + + pub fn next_logup_generic_layer( + numerators: *const CudaSecureField, + denominators: *const CudaSecureField, + size: usize, + next_numerators: *const CudaSecureField, + next_denominators: *const CudaSecureField, + next_size: usize, + ); + + pub fn next_logup_multiplicities_layer( + numerators: *const CudaBaseField, + denominators: *const CudaSecureField, + size: usize, + next_numerators: *const CudaSecureField, + next_denominators: *const CudaSecureField, + next_size: usize, + ); + + pub fn next_logup_singles_layer( + denominators: *const CudaSecureField, + size: usize, + next_numerators: *const CudaSecureField, + next_denominators: *const CudaSecureField, + next_size: usize, + ); + + pub fn eval_grand_product_sum( + eq_evals: *const CudaSecureField, + input_layer: *const CudaSecureField, + n_terms: usize, + eval_at_0: *const CudaSecureField, + eval_at_2: *const CudaSecureField, + ); + + pub fn eval_logup_generic_sum( + eq_evals: *const CudaSecureField, + numerators: *const CudaSecureField, + denominators: *const CudaSecureField, + n_terms: usize, + lambda: CudaSecureField, + eval_at_0: *const CudaSecureField, + eval_at_2: *const CudaSecureField, + ); + + pub fn eval_logup_multiplicities_sum( + eq_evals: *const CudaSecureField, + numerators: *const CudaBaseField, + denominators: *const CudaSecureField, + n_terms: usize, + lambda: CudaSecureField, + eval_at_0: *const CudaSecureField, + eval_at_2: *const CudaSecureField, + ); + + pub fn eval_logup_singles_sum( + eq_evals: *const CudaSecureField, + denominators: *const CudaSecureField, + n_terms: usize, + lambda: CudaSecureField, + eval_at_0: *const CudaSecureField, + eval_at_2: *const CudaSecureField, + ); + } diff --git a/stwo_gpu_backend/src/lookups/gkr.rs b/stwo_gpu_backend/src/lookups/gkr.rs index 1e71156..5a56d17 100644 --- a/stwo_gpu_backend/src/lookups/gkr.rs +++ b/stwo_gpu_backend/src/lookups/gkr.rs @@ -1,13 +1,23 @@ use crate::CudaBackend; +use crate::cuda::{bindings, BaseFieldVec, SecureFieldVec}; +use crate::cuda::bindings::{CudaBaseField, CudaSecureField}; + +use num_traits::Zero; + use stwo_prover::core::{ - fields::{m31::BaseField, qm31::SecureField}, + backend::Column, + fields::{ + m31::BaseField, + qm31::SecureField, + }, lookups::{ - gkr_prover::GkrOps, + gkr_prover::{correct_sum_as_poly_in_first_variable, EqEvals, GkrOps, Layer}, mle::{Mle, MleOps}, - }, -}; -use crate::cuda::{bindings, SecureFieldVec}; -use crate::cuda::bindings::CudaSecureField; + sumcheck::MultivariatePolyOracle, + },}; + + +const SINGLE_EVALUATION_SIZE: usize = 1; impl GkrOps for CudaBackend { fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle { @@ -30,24 +40,245 @@ impl GkrOps for CudaBackend { fn next_layer( layer: &stwo_prover::core::lookups::gkr_prover::Layer, ) -> stwo_prover::core::lookups::gkr_prover::Layer { - todo!() + match layer { + Layer::GrandProduct(col) => + next_grand_product_layer(col), + Layer::LogUpGeneric {numerators, denominators} => + next_logup_generic_layer::(numerators, denominators), + Layer::LogUpMultiplicities {numerators,denominators} => + next_logup_multiplicities_layer::(numerators, denominators), + Layer::LogUpSingles { denominators } => + next_logup_singles_layer::(denominators), + } } fn sum_as_poly_in_first_variable( h: &stwo_prover::core::lookups::gkr_prover::GkrMultivariatePolyOracle<'_, Self>, claim: SecureField, ) -> stwo_prover::core::lookups::utils::UnivariatePoly { - todo!() + let n_variables = h.n_variables(); + assert!(!n_variables.is_zero()); + let n_terms = 1 << (n_variables - 1); + let eq_evals = h.eq_evals.as_ref(); + // Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean hypercube. + let y = eq_evals.y(); + let lambda = h.lambda; + + let (mut eval_at_0, mut eval_at_2) = match &h.input_layer { + Layer::GrandProduct(col) => + eval_grand_product_sum(eq_evals, col, n_terms), + Layer::LogUpGeneric {numerators, denominators} => + eval_logup_generic_sum(eq_evals, numerators, denominators, n_terms, lambda), + Layer::LogUpMultiplicities {numerators, denominators} => + eval_logup_multiplicities_sum(eq_evals, numerators, denominators, n_terms, lambda), + Layer::LogUpSingles { denominators } => + eval_logup_singles_sum(eq_evals, denominators, n_terms, lambda) + }; + + // TODO(Daniel): Create polynomial interpolation kernel + eval_at_0 *= h.eq_fixed_var_correction; + eval_at_2 *= h.eq_fixed_var_correction; + correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, n_variables) + } +} + +fn eval_grand_product_sum( + eq_evals: &EqEvals, + input_layer: &Mle, + n_terms: usize, +) -> (SecureField, SecureField) { + let eval_at_0 = SecureFieldVec::new_uninitialized(SINGLE_EVALUATION_SIZE); + let eval_at_2 = SecureFieldVec::new_uninitialized(SINGLE_EVALUATION_SIZE); + + unsafe { + bindings::eval_grand_product_sum( + eq_evals.device_ptr as *const CudaSecureField, + input_layer.device_ptr as *const CudaSecureField, + n_terms, + eval_at_0.device_ptr as *const CudaSecureField, + eval_at_2.device_ptr as *const CudaSecureField); + }; + + (eval_at_0.to_cpu()[0].clone(), eval_at_2.to_cpu()[0].clone()) +} + +fn eval_logup_generic_sum( + eq_evals: &EqEvals, + numerators: &Mle, + denominators: &Mle, + n_terms: usize, + lambda: SecureField, +) -> (SecureField, SecureField) { + let eval_at_0 = SecureFieldVec::new_uninitialized(SINGLE_EVALUATION_SIZE); + let eval_at_2 = SecureFieldVec::new_uninitialized(SINGLE_EVALUATION_SIZE); + + unsafe { + bindings::eval_logup_generic_sum( + eq_evals.device_ptr as *const CudaSecureField, + numerators.device_ptr as *const CudaSecureField, + denominators.device_ptr as *const CudaSecureField, + n_terms, + CudaSecureField::from(lambda), + eval_at_0.device_ptr as *const CudaSecureField, + eval_at_2.device_ptr as *const CudaSecureField); + }; + + (eval_at_0.to_cpu()[0].clone(), eval_at_2.to_cpu()[0].clone()) +} + +fn eval_logup_multiplicities_sum( + eq_evals: &EqEvals, + numerators: &Mle, + denominators: &Mle, + n_terms: usize, + lambda: SecureField, +) -> (SecureField, SecureField) { + let eval_at_0 = SecureFieldVec::new_uninitialized(SINGLE_EVALUATION_SIZE); + let eval_at_2 = SecureFieldVec::new_uninitialized(SINGLE_EVALUATION_SIZE); + + unsafe { + bindings::eval_logup_multiplicities_sum( + eq_evals.device_ptr as *const CudaSecureField, + numerators.device_ptr as *const CudaBaseField, + denominators.device_ptr as *const CudaSecureField, + n_terms, + CudaSecureField::from(lambda), + eval_at_0.device_ptr as *const CudaSecureField, + eval_at_2.device_ptr as *const CudaSecureField); + }; + + (eval_at_0.to_cpu()[0].clone(), eval_at_2.to_cpu()[0].clone()) +} + +fn eval_logup_singles_sum( + eq_evals: &EqEvals, + denominators: &Mle, + n_terms: usize, + lambda: SecureField, +) -> (SecureField, SecureField) { + let eval_at_0 = SecureFieldVec::new_uninitialized(SINGLE_EVALUATION_SIZE); + let eval_at_2 = SecureFieldVec::new_uninitialized(SINGLE_EVALUATION_SIZE); + + unsafe { + bindings::eval_logup_singles_sum( + eq_evals.device_ptr as *const CudaSecureField, + denominators.device_ptr as *const CudaSecureField, + n_terms, + CudaSecureField::from(lambda), + eval_at_0.device_ptr as *const CudaSecureField, + eval_at_2.device_ptr as *const CudaSecureField); + }; + + (eval_at_0.to_cpu()[0].clone(), eval_at_2.to_cpu()[0].clone()) +} + +fn next_grand_product_layer(layer: &Mle) -> Layer { + let next_layer_size = layer.size / 2; + let next_layer = SecureFieldVec::new_uninitialized(next_layer_size); + + unsafe { + bindings::next_grand_product_layer( + layer.device_ptr as *const CudaSecureField, + layer.size, + next_layer.device_ptr as *const CudaSecureField, + next_layer_size + ); + }; + + Layer::GrandProduct(Mle::new(next_layer)) +} + +fn next_logup_generic_layer( + numerators: &Mle, + denominators: &Mle, +) -> Layer { + let next_layer_len = denominators.len() / 2; + let next_numerators = SecureFieldVec::new_uninitialized(next_layer_len); + let next_denominators = SecureFieldVec::new_uninitialized(next_layer_len); + + unsafe { + bindings::next_logup_generic_layer( + numerators.device_ptr as *const CudaSecureField, + denominators.device_ptr as *const CudaSecureField, + denominators.size, + next_numerators.device_ptr as *const CudaSecureField, + next_denominators.device_ptr as *const CudaSecureField, + next_layer_len + ); + }; + + Layer::LogUpGeneric { + numerators: Mle::new(next_numerators), + denominators: Mle::new(next_denominators), + } +} + +fn next_logup_multiplicities_layer( + numerators: &Mle, + denominators: &Mle, +) -> Layer { + let next_layer_len = denominators.len() / 2; + let next_numerators = SecureFieldVec::new_uninitialized(next_layer_len); + let next_denominators = SecureFieldVec::new_uninitialized(next_layer_len); + + unsafe { + bindings::next_logup_multiplicities_layer( + numerators.device_ptr as *const CudaBaseField, + denominators.device_ptr as *const CudaSecureField, + denominators.size, + next_numerators.device_ptr as *const CudaSecureField, + next_denominators.device_ptr as *const CudaSecureField, + next_layer_len + ); + }; + + Layer::LogUpGeneric { + numerators: Mle::new(next_numerators), + denominators: Mle::new(next_denominators), + } +} + +fn next_logup_singles_layer( + denominators: &Mle, +) -> Layer { + let next_layer_len = denominators.len() / 2; + let next_numerators = SecureFieldVec::new_uninitialized(next_layer_len); + let next_denominators = SecureFieldVec::new_uninitialized(next_layer_len); + + unsafe { + bindings::next_logup_singles_layer( + denominators.device_ptr as *const CudaSecureField, + denominators.size, + next_numerators.device_ptr as *const CudaSecureField, + next_denominators.device_ptr as *const CudaSecureField, + next_layer_len + ); + }; + + Layer::LogUpGeneric { + numerators: Mle::new(next_numerators), + denominators: Mle::new(next_denominators), } } mod tests { + use std::iter::zip; + use itertools::Itertools; - use crate::CudaBackend; + use num_traits::One; + use rand::{Rng, SeedableRng}; + use rand::rngs::SmallRng; + use stwo_prover::core::backend::{Column, CpuBackend}; - use stwo_prover::core::fields::m31::{BaseField, M31}; - use stwo_prover::core::fields::qm31::SecureField; - use stwo_prover::core::lookups::gkr_prover::GkrOps; + use stwo_prover::core::channel::{Blake2sChannel, Channel}; + use stwo_prover::core::fields::{ExtensionOf, Field}; + use stwo_prover::core::fields::{m31::{BaseField, M31}, qm31::SecureField}; + use stwo_prover::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer}; + use stwo_prover::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError}; + use stwo_prover::core::lookups::mle::{Mle, MleOps}; + use stwo_prover::core::lookups::utils::Fraction; + + use crate::CudaBackend; #[test] fn gen_eq_evals_matches_cpu() { @@ -63,4 +294,180 @@ mod tests { assert_eq!(gpu_eq_evals.to_cpu(), *cpu_eq_evals); } + + #[test] + fn grand_product_works() { + const N: usize = 1 << 18; + let values = Blake2sChannel::default().draw_felts(N); + let product = values.iter().product(); + + let col_gpu = Mle::::new(values.clone().into_iter().collect()); + let col_cpu = Mle::::new(values.into_iter().collect()); + + let input_layer = Layer::GrandProduct(col_gpu.clone()); + + let (proof, _) = prove_batch(&mut Blake2sChannel::default(), vec![input_layer]); + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut Blake2sChannel::default()).unwrap(); + + assert_eq!(proof.output_claims_by_instance, [vec![product]]); + assert_eq!( + claims_to_verify_by_instance, + [vec![eval_at_point(&col_cpu, &ood_point)]] + ); + } + + #[test] + fn logup_with_generic_trace_works() { + const N: usize = 1 << 18; + let mut rng = SmallRng::seed_from_u64(0); + let numerator_values = (0..N).map(|_| rng.gen()).collect::>(); + let denominator_values = (0..N).map(|_| rng.gen()).collect::>(); + let sum = zip(&numerator_values, &denominator_values) + .map(|(&n, &d)| Fraction::new(n, d)) + .sum::>(); + let numerators = Mle::::new(numerator_values.clone().into_iter().collect()); + let denominators = Mle::::new(denominator_values.clone().into_iter().collect()); + let numerators_cpu = Mle::::new(numerator_values.into_iter().collect()); + let denominators_cpu = Mle::::new(denominator_values.into_iter().collect()); + + let top_layer = Layer::LogUpGeneric { + numerators: numerators.clone(), + denominators: denominators.clone(), + }; + let (proof, _) = prove_batch(&mut Blake2sChannel::default(), vec![top_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut Blake2sChannel::default()).unwrap(); + + assert_eq!(claims_to_verify_by_instance.len(), 1); + assert_eq!(proof.output_claims_by_instance.len(), 1); + assert_eq!( + claims_to_verify_by_instance[0], + [ + eval_at_point(&numerators_cpu, &ood_point), + eval_at_point(&denominators_cpu, &ood_point) + ] + ); + assert_eq!( + proof.output_claims_by_instance[0], + [sum.numerator, sum.denominator] + ); + } + + #[test] + fn logup_with_multiplicities_trace_works() { + const N: usize = 1 << 18; + let mut rng = SmallRng::seed_from_u64(0); + let numerator_values = (0..N).map(|_| rng.gen()).collect::>(); + let denominator_values = (0..N).map(|_| rng.gen()).collect::>(); + let sum = zip(&numerator_values, &denominator_values) + .map(|(&n, &d)| Fraction::new(n.into(), d)) + .sum::>(); + let numerators = Mle::::new(numerator_values.clone().into_iter().collect()); + let denominators = Mle::::new(denominator_values.clone().into_iter().collect()); + let numerators_cpu = Mle::::new(numerator_values.into_iter().collect()); + let denominators_cpu = Mle::::new(denominator_values.into_iter().collect()); + + let top_layer = Layer::LogUpMultiplicities { + numerators: numerators.clone(), + denominators: denominators.clone(), + }; + let (proof, _) = prove_batch(&mut Blake2sChannel::default(), vec![top_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut Blake2sChannel::default()).unwrap(); + + assert_eq!(claims_to_verify_by_instance.len(), 1); + assert_eq!(proof.output_claims_by_instance.len(), 1); + assert_eq!( + claims_to_verify_by_instance[0], + [ + eval_at_point(&numerators_cpu.into(), &ood_point), + eval_at_point(&denominators_cpu, &ood_point) + ] + ); + assert_eq!( + proof.output_claims_by_instance[0], + [sum.numerator, sum.denominator] + ); + } + + #[test] + fn logup_with_singles_trace_works() { + const N: usize = 1 << 18; + let mut rng = SmallRng::seed_from_u64(0); + let denominator_values = (0..N).map(|_| rng.gen()).collect::>(); + let sum = denominator_values + .iter() + .map(|&d| Fraction::new(SecureField::one(), d)) + .sum::>(); + let denominators = Mle::::new(denominator_values.clone().into_iter().collect()); + let denominators_cpu = Mle::::new(denominator_values.into_iter().collect()); + + let top_layer = Layer::LogUpSingles { + denominators: denominators.clone(), + }; + let (proof, _) = prove_batch(&mut Blake2sChannel::default(), vec![top_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut Blake2sChannel::default()).unwrap(); + + assert_eq!(claims_to_verify_by_instance.len(), 1); + assert_eq!(proof.output_claims_by_instance.len(), 1); + assert_eq!( + claims_to_verify_by_instance[0], + [ + SecureField::one(), + eval_at_point(&denominators_cpu, &ood_point) + ] + ); + assert_eq!( + proof.output_claims_by_instance[0], + [sum.numerator, sum.denominator] + ); + } + + // CPU evaluation helper function + pub(crate) fn eval_at_point(input: &Mle, point: &[SecureField]) -> SecureField + where + F: Field, + SecureField: ExtensionOf, + B: MleOps, + { + pub fn eval(mle_evals: &[SecureField], p: &[SecureField]) -> SecureField { + match p { + [] => mle_evals[0], + &[p_i, ref p @ ..] => { + let (lhs, rhs) = mle_evals.split_at(mle_evals.len() / 2); + let lhs_eval = eval(lhs, p); + let rhs_eval = eval(rhs, p); + // Equivalent to `eq(0, p_i) * lhs_eval + eq(1, p_i) * rhs_eval`. + p_i * (rhs_eval - lhs_eval) + lhs_eval + } + } + } + + let mle_evals = input + .clone() + .into_evals() + .to_cpu() + .into_iter() + .map(|v| v.into()) + .collect::>(); + + eval(&mle_evals, point) + } } \ No newline at end of file