From a26912dd5391e71dd51915f7e03f5527a698eb3c Mon Sep 17 00:00:00 2001 From: Sachin Mehta Date: Mon, 21 Oct 2024 14:52:34 +0530 Subject: [PATCH 1/5] Adding tests for blocking stub generation. --- .../golden/BlockingBidiStreamingService.kt | 166 +++++++++++++++++ .../golden/BlockingClientStreamingService.kt | 171 ++++++++++++++++++ .../golden/BlockingServerStreamingService.kt | 163 +++++++++++++++++ .../wire/kotlin/grpcserver/StubTest.kt | 81 +++++++++ 4 files changed, 581 insertions(+) create mode 100644 server-generator/src/test/golden/BlockingBidiStreamingService.kt create mode 100644 server-generator/src/test/golden/BlockingClientStreamingService.kt create mode 100644 server-generator/src/test/golden/BlockingServerStreamingService.kt diff --git a/server-generator/src/test/golden/BlockingBidiStreamingService.kt b/server-generator/src/test/golden/BlockingBidiStreamingService.kt new file mode 100644 index 0000000..ea66ffd --- /dev/null +++ b/server-generator/src/test/golden/BlockingBidiStreamingService.kt @@ -0,0 +1,166 @@ +// Code generated by Wire protocol buffer compiler, do not edit. +package test + +import com.google.protobuf.DescriptorProtos +import com.google.protobuf.Descriptors +import com.squareup.wire.kotlin.grpcserver.MessageSinkAdapter +import com.squareup.wire.kotlin.grpcserver.MessageSourceAdapter +import com.squareup.wire.kotlin.grpcserver.WireBindableService +import com.squareup.wire.kotlin.grpcserver.WireMethodMarshaller +import io.grpc.CallOptions +import io.grpc.Channel +import io.grpc.MethodDescriptor +import io.grpc.ServerServiceDefinition +import io.grpc.ServiceDescriptor +import io.grpc.ServiceDescriptor.newBuilder +import io.grpc.stub.AbstractStub +import io.grpc.stub.StreamObserver +import java.io.InputStream +import java.lang.Class +import java.lang.UnsupportedOperationException +import java.util.concurrent.ExecutorService +import kotlin.Array +import kotlin.String +import kotlin.collections.Map +import kotlin.collections.Set +import kotlin.jvm.Volatile +import io.grpc.stub.ClientCalls.asyncBidiStreamingCall as clientCallsAsyncBidiStreamingCall +import io.grpc.stub.ServerCalls.asyncBidiStreamingCall as serverCallsAsyncBidiStreamingCall + +public object TestServiceWireGrpc { + public const val SERVICE_NAME: String = "test.TestService" + + @Volatile + private var serviceDescriptor: ServiceDescriptor? = null + + private val descriptorMap: Map = + createDescriptorMap0() + + + @Volatile + private var getTestRPCMethod: MethodDescriptor? = null + + private fun descriptorFor(`data`: Array): DescriptorProtos.FileDescriptorProto { + val str = data.fold(java.lang.StringBuilder()) { b, s -> b.append(s) }.toString() + val bytes = java.util.Base64.getDecoder().decode(str) + return DescriptorProtos.FileDescriptorProto.parseFrom(bytes) + } + + private fun fileDescriptor(path: String, visited: Set): Descriptors.FileDescriptor { + val proto = descriptorMap[path]!! + val deps = proto.dependencyList.filter { !visited.contains(it) }.map { fileDescriptor(it, + visited + path) } + return Descriptors.FileDescriptor.buildFrom(proto, deps.toTypedArray()) + } + + private fun createDescriptorMap0(): Map { + val subMap = mapOf( + "service.proto" to descriptorFor(arrayOf( + "Cg1zZXJ2aWNlLnByb3RvEgR0ZXN0IgYKBFRlc3QyNAoLVGVzdFNlcnZpY2USJQoHVGVzdFJQQxIKLnRl", + "c3QuVGVzdBoKLnRlc3QuVGVzdCgBMAE=", + )), + ) + return subMap + } + + public fun getServiceDescriptor(): ServiceDescriptor? { + var result = serviceDescriptor + if (result == null) { + synchronized(TestServiceWireGrpc::class) { + result = serviceDescriptor + if (result == null) { + result = newBuilder(SERVICE_NAME) + .addMethod(getTestRPCMethod()) + .setSchemaDescriptor(io.grpc.protobuf.ProtoFileDescriptorSupplier { + fileDescriptor("service.proto", emptySet()) + }) + .build() + serviceDescriptor = result + } + } + } + return result + } + + public fun getTestRPCMethod(): MethodDescriptor { + var result: MethodDescriptor? = getTestRPCMethod + if (result == null) { + synchronized(TestServiceWireGrpc::class) { + result = getTestRPCMethod + if (result == null) { + getTestRPCMethod = MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.BIDI_STREAMING) + .setFullMethodName( + MethodDescriptor.generateFullMethodName( + "test.TestService", "TestRPC" + ) + ) + .setSampledToLocalTracing(true) + .setRequestMarshaller(TestServiceImplBase.TestMarshaller()) + .setResponseMarshaller(TestServiceImplBase.TestMarshaller()) + .build() + } + } + } + return getTestRPCMethod!! + } + + public fun newStub(channel: Channel): TestServiceStub = TestServiceStub(channel) + + public fun newBlockingStub(channel: Channel): TestServiceBlockingStub = + TestServiceBlockingStub(channel) + + public abstract class TestServiceImplBase : WireBindableService { + public open fun TestRPC(response: StreamObserver): StreamObserver = throw + UnsupportedOperationException() + + override fun bindService(): ServerServiceDefinition = + ServerServiceDefinition.builder(getServiceDescriptor()).addMethod( + getTestRPCMethod(), + serverCallsAsyncBidiStreamingCall(this@TestServiceImplBase::TestRPC) + ).build() + + public class TestMarshaller : WireMethodMarshaller { + override fun stream(`value`: Test): InputStream = Test.ADAPTER.encode(value).inputStream() + + override fun marshalledClass(): Class = Test::class.java + + override fun parse(stream: InputStream): Test = Test.ADAPTER.decode(stream) + } + } + + public class BindableAdapter( + private val streamExecutor: ExecutorService, + private val service: () -> TestServiceBlockingServer, + ) : TestServiceImplBase() { + override fun TestRPC(response: StreamObserver): StreamObserver { + val requestStream = MessageSourceAdapter() + streamExecutor.submit { + service().TestRPC(requestStream, MessageSinkAdapter(response)) + } + return requestStream + } + } + + public class TestServiceStub : AbstractStub { + internal constructor(channel: Channel) : super(channel) + + internal constructor(channel: Channel, callOptions: CallOptions) : super(channel, callOptions) + + override fun build(channel: Channel, callOptions: CallOptions): TestServiceStub = + TestServiceStub(channel, callOptions) + + public fun TestRPC(response: StreamObserver): StreamObserver = + clientCallsAsyncBidiStreamingCall(channel.newCall(getTestRPCMethod(), callOptions), + response) + } + + public class TestServiceBlockingStub : AbstractStub { + internal constructor(channel: Channel) : super(channel) + + internal constructor(channel: Channel, callOptions: CallOptions) : super(channel, callOptions) + + override fun build(channel: Channel, callOptions: CallOptions): TestServiceBlockingStub = + TestServiceBlockingStub(channel, callOptions) + } +} diff --git a/server-generator/src/test/golden/BlockingClientStreamingService.kt b/server-generator/src/test/golden/BlockingClientStreamingService.kt new file mode 100644 index 0000000..21f8fd0 --- /dev/null +++ b/server-generator/src/test/golden/BlockingClientStreamingService.kt @@ -0,0 +1,171 @@ +// Code generated by Wire protocol buffer compiler, do not edit. +package test + +import com.google.protobuf.DescriptorProtos +import com.google.protobuf.Descriptors +import com.squareup.wire.kotlin.grpcserver.MessageSourceAdapter +import com.squareup.wire.kotlin.grpcserver.WireBindableService +import com.squareup.wire.kotlin.grpcserver.WireMethodMarshaller +import io.grpc.CallOptions +import io.grpc.Channel +import io.grpc.MethodDescriptor +import io.grpc.ServerServiceDefinition +import io.grpc.ServiceDescriptor +import io.grpc.ServiceDescriptor.newBuilder +import io.grpc.stub.AbstractStub +import io.grpc.stub.ClientCalls.blockingServerStreamingCall +import io.grpc.stub.StreamObserver +import java.io.InputStream +import java.lang.Class +import java.lang.UnsupportedOperationException +import java.util.concurrent.ExecutorService +import kotlin.Array +import kotlin.String +import kotlin.collections.Iterator +import kotlin.collections.Map +import kotlin.collections.Set +import kotlin.jvm.Volatile +import io.grpc.stub.ClientCalls.asyncClientStreamingCall as clientCallsAsyncClientStreamingCall +import io.grpc.stub.ServerCalls.asyncClientStreamingCall as serverCallsAsyncClientStreamingCall + +public object TestServiceWireGrpc { + public const val SERVICE_NAME: String = "test.TestService" + + @Volatile + private var serviceDescriptor: ServiceDescriptor? = null + + private val descriptorMap: Map = + createDescriptorMap0() + + + @Volatile + private var getTestRPCMethod: MethodDescriptor? = null + + private fun descriptorFor(`data`: Array): DescriptorProtos.FileDescriptorProto { + val str = data.fold(java.lang.StringBuilder()) { b, s -> b.append(s) }.toString() + val bytes = java.util.Base64.getDecoder().decode(str) + return DescriptorProtos.FileDescriptorProto.parseFrom(bytes) + } + + private fun fileDescriptor(path: String, visited: Set): Descriptors.FileDescriptor { + val proto = descriptorMap[path]!! + val deps = proto.dependencyList.filter { !visited.contains(it) }.map { fileDescriptor(it, + visited + path) } + return Descriptors.FileDescriptor.buildFrom(proto, deps.toTypedArray()) + } + + private fun createDescriptorMap0(): Map { + val subMap = mapOf( + "service.proto" to descriptorFor(arrayOf( + "Cg1zZXJ2aWNlLnByb3RvEgR0ZXN0IgYKBFRlc3QyMgoLVGVzdFNlcnZpY2USIwoHVGVzdFJQQxIKLnRl", + "c3QuVGVzdBoKLnRlc3QuVGVzdCgB", + )), + ) + return subMap + } + + public fun getServiceDescriptor(): ServiceDescriptor? { + var result = serviceDescriptor + if (result == null) { + synchronized(TestServiceWireGrpc::class) { + result = serviceDescriptor + if (result == null) { + result = newBuilder(SERVICE_NAME) + .addMethod(getTestRPCMethod()) + .setSchemaDescriptor(io.grpc.protobuf.ProtoFileDescriptorSupplier { + fileDescriptor("service.proto", emptySet()) + }) + .build() + serviceDescriptor = result + } + } + } + return result + } + + public fun getTestRPCMethod(): MethodDescriptor { + var result: MethodDescriptor? = getTestRPCMethod + if (result == null) { + synchronized(TestServiceWireGrpc::class) { + result = getTestRPCMethod + if (result == null) { + getTestRPCMethod = MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.CLIENT_STREAMING) + .setFullMethodName( + MethodDescriptor.generateFullMethodName( + "test.TestService", "TestRPC" + ) + ) + .setSampledToLocalTracing(true) + .setRequestMarshaller(TestServiceImplBase.TestMarshaller()) + .setResponseMarshaller(TestServiceImplBase.TestMarshaller()) + .build() + } + } + } + return getTestRPCMethod!! + } + + public fun newStub(channel: Channel): TestServiceStub = TestServiceStub(channel) + + public fun newBlockingStub(channel: Channel): TestServiceBlockingStub = + TestServiceBlockingStub(channel) + + public abstract class TestServiceImplBase : WireBindableService { + public open fun TestRPC(response: StreamObserver): StreamObserver = throw + UnsupportedOperationException() + + override fun bindService(): ServerServiceDefinition = + ServerServiceDefinition.builder(getServiceDescriptor()).addMethod( + getTestRPCMethod(), + serverCallsAsyncClientStreamingCall(this@TestServiceImplBase::TestRPC) + ).build() + + public class TestMarshaller : WireMethodMarshaller { + override fun stream(`value`: Test): InputStream = Test.ADAPTER.encode(value).inputStream() + + override fun marshalledClass(): Class = Test::class.java + + override fun parse(stream: InputStream): Test = Test.ADAPTER.decode(stream) + } + } + + public class BindableAdapter( + private val streamExecutor: ExecutorService, + private val service: () -> TestServiceBlockingServer, + ) : TestServiceImplBase() { + override fun TestRPC(response: StreamObserver): StreamObserver { + val requestStream = MessageSourceAdapter() + streamExecutor.submit { + response.onNext(service().TestRPC(requestStream)) + response.onCompleted() + } + return requestStream + } + } + + public class TestServiceStub : AbstractStub { + internal constructor(channel: Channel) : super(channel) + + internal constructor(channel: Channel, callOptions: CallOptions) : super(channel, callOptions) + + override fun build(channel: Channel, callOptions: CallOptions): TestServiceStub = + TestServiceStub(channel, callOptions) + + public fun TestRPC(response: StreamObserver): StreamObserver = + clientCallsAsyncClientStreamingCall(channel.newCall(getTestRPCMethod(), callOptions), + response) + } + + public class TestServiceBlockingStub : AbstractStub { + internal constructor(channel: Channel) : super(channel) + + internal constructor(channel: Channel, callOptions: CallOptions) : super(channel, callOptions) + + override fun build(channel: Channel, callOptions: CallOptions): TestServiceBlockingStub = + TestServiceBlockingStub(channel, callOptions) + + public fun TestRPC(request: Test): Iterator = blockingServerStreamingCall(channel, + getTestRPCMethod(), callOptions, request) + } +} diff --git a/server-generator/src/test/golden/BlockingServerStreamingService.kt b/server-generator/src/test/golden/BlockingServerStreamingService.kt new file mode 100644 index 0000000..e3fcea0 --- /dev/null +++ b/server-generator/src/test/golden/BlockingServerStreamingService.kt @@ -0,0 +1,163 @@ +// Code generated by Wire protocol buffer compiler, do not edit. +package test + +import com.google.protobuf.DescriptorProtos +import com.google.protobuf.Descriptors +import com.squareup.wire.kotlin.grpcserver.MessageSinkAdapter +import com.squareup.wire.kotlin.grpcserver.WireBindableService +import com.squareup.wire.kotlin.grpcserver.WireMethodMarshaller +import io.grpc.CallOptions +import io.grpc.Channel +import io.grpc.MethodDescriptor +import io.grpc.ServerServiceDefinition +import io.grpc.ServiceDescriptor +import io.grpc.ServiceDescriptor.newBuilder +import io.grpc.stub.AbstractStub +import io.grpc.stub.StreamObserver +import java.io.InputStream +import java.lang.Class +import java.lang.UnsupportedOperationException +import java.util.concurrent.ExecutorService +import kotlin.Array +import kotlin.String +import kotlin.Unit +import kotlin.collections.Map +import kotlin.collections.Set +import kotlin.jvm.Volatile +import io.grpc.stub.ClientCalls.asyncServerStreamingCall as clientCallsAsyncServerStreamingCall +import io.grpc.stub.ServerCalls.asyncServerStreamingCall as serverCallsAsyncServerStreamingCall + +public object TestServiceWireGrpc { + public const val SERVICE_NAME: String = "test.TestService" + + @Volatile + private var serviceDescriptor: ServiceDescriptor? = null + + private val descriptorMap: Map = + createDescriptorMap0() + + + @Volatile + private var getTestRPCMethod: MethodDescriptor? = null + + private fun descriptorFor(`data`: Array): DescriptorProtos.FileDescriptorProto { + val str = data.fold(java.lang.StringBuilder()) { b, s -> b.append(s) }.toString() + val bytes = java.util.Base64.getDecoder().decode(str) + return DescriptorProtos.FileDescriptorProto.parseFrom(bytes) + } + + private fun fileDescriptor(path: String, visited: Set): Descriptors.FileDescriptor { + val proto = descriptorMap[path]!! + val deps = proto.dependencyList.filter { !visited.contains(it) }.map { fileDescriptor(it, + visited + path) } + return Descriptors.FileDescriptor.buildFrom(proto, deps.toTypedArray()) + } + + private fun createDescriptorMap0(): Map { + val subMap = mapOf( + "service.proto" to descriptorFor(arrayOf( + "Cg1zZXJ2aWNlLnByb3RvEgR0ZXN0IgYKBFRlc3QyMgoLVGVzdFNlcnZpY2USIwoHVGVzdFJQQxIKLnRl", + "c3QuVGVzdBoKLnRlc3QuVGVzdDAB", + )), + ) + return subMap + } + + public fun getServiceDescriptor(): ServiceDescriptor? { + var result = serviceDescriptor + if (result == null) { + synchronized(TestServiceWireGrpc::class) { + result = serviceDescriptor + if (result == null) { + result = newBuilder(SERVICE_NAME) + .addMethod(getTestRPCMethod()) + .setSchemaDescriptor(io.grpc.protobuf.ProtoFileDescriptorSupplier { + fileDescriptor("service.proto", emptySet()) + }) + .build() + serviceDescriptor = result + } + } + } + return result + } + + public fun getTestRPCMethod(): MethodDescriptor { + var result: MethodDescriptor? = getTestRPCMethod + if (result == null) { + synchronized(TestServiceWireGrpc::class) { + result = getTestRPCMethod + if (result == null) { + getTestRPCMethod = MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.SERVER_STREAMING) + .setFullMethodName( + MethodDescriptor.generateFullMethodName( + "test.TestService", "TestRPC" + ) + ) + .setSampledToLocalTracing(true) + .setRequestMarshaller(TestServiceImplBase.TestMarshaller()) + .setResponseMarshaller(TestServiceImplBase.TestMarshaller()) + .build() + } + } + } + return getTestRPCMethod!! + } + + public fun newStub(channel: Channel): TestServiceStub = TestServiceStub(channel) + + public fun newBlockingStub(channel: Channel): TestServiceBlockingStub = + TestServiceBlockingStub(channel) + + public abstract class TestServiceImplBase : WireBindableService { + public open fun TestRPC(request: Test, response: StreamObserver): Unit = throw + UnsupportedOperationException() + + override fun bindService(): ServerServiceDefinition = + ServerServiceDefinition.builder(getServiceDescriptor()).addMethod( + getTestRPCMethod(), + serverCallsAsyncServerStreamingCall(this@TestServiceImplBase::TestRPC) + ).build() + + public class TestMarshaller : WireMethodMarshaller { + override fun stream(`value`: Test): InputStream = Test.ADAPTER.encode(value).inputStream() + + override fun marshalledClass(): Class = Test::class.java + + override fun parse(stream: InputStream): Test = Test.ADAPTER.decode(stream) + } + } + + public class BindableAdapter( + private val streamExecutor: ExecutorService, + private val service: () -> TestServiceBlockingServer, + ) : TestServiceImplBase() { + override fun TestRPC(request: Test, response: StreamObserver) { + service().TestRPC(request, MessageSinkAdapter(response)) + } + } + + public class TestServiceStub : AbstractStub { + internal constructor(channel: Channel) : super(channel) + + internal constructor(channel: Channel, callOptions: CallOptions) : super(channel, callOptions) + + override fun build(channel: Channel, callOptions: CallOptions): TestServiceStub = + TestServiceStub(channel, callOptions) + + public fun TestRPC(request: Test, response: StreamObserver) { + clientCallsAsyncServerStreamingCall(channel.newCall(getTestRPCMethod(), callOptions), request, + response) + } + } + + public class TestServiceBlockingStub : AbstractStub { + internal constructor(channel: Channel) : super(channel) + + internal constructor(channel: Channel, callOptions: CallOptions) : super(channel, callOptions) + + override fun build(channel: Channel, callOptions: CallOptions): TestServiceBlockingStub = + TestServiceBlockingStub(channel, callOptions) + } +} diff --git a/server-generator/src/test/java/com/squareup/wire/kotlin/grpcserver/StubTest.kt b/server-generator/src/test/java/com/squareup/wire/kotlin/grpcserver/StubTest.kt index 27cbe55..cf77d10 100644 --- a/server-generator/src/test/java/com/squareup/wire/kotlin/grpcserver/StubTest.kt +++ b/server-generator/src/test/java/com/squareup/wire/kotlin/grpcserver/StubTest.kt @@ -17,9 +17,12 @@ package com.squareup.wire.kotlin.grpcserver import com.squareup.kotlinpoet.FileSpec import com.squareup.kotlinpoet.TypeSpec +import com.squareup.wire.WireTestLogger import com.squareup.wire.buildSchema import com.squareup.wire.kotlin.grpcserver.GoldenTestUtils.assertFileEquals +import com.squareup.wire.schema.SchemaHandler import okio.Path.Companion.toPath +import okio.fakefilesystem.FakeFileSystem import org.junit.Test import kotlin.test.assertEquals @@ -185,6 +188,55 @@ class StubTest { ) } + + @Test + fun `generates stubs for blocking bidi streaming rpc`() { + val code = stubCodeForBlocking( + """ + syntax = "proto2"; + package test; + + message Test {} + service TestService { + rpc TestRPC(stream Test) returns (stream Test){} + } + """.trimMargin(), + ) + assertFileEquals("BlockingBidiStreamingService.kt", code) + } + + @Test + fun `generates stubs for blocking server streaming rpc`() { + val code = stubCodeForBlocking( + """ + syntax = "proto2"; + package test; + + message Test {} + service TestService { + rpc TestRPC(Test) returns (stream Test){} + } + """.trimMargin(), + ) + assertFileEquals("BlockingServerStreamingService.kt", code) + } + + @Test + fun `generates stubs for blocking client streaming rpc`() { + val code = stubCodeForBlocking( + """ + syntax = "proto2"; + package test; + + message Test {} + service TestService { + rpc TestRPC(stream Test) returns (Test){} + } + """.trimMargin(), + ) + assertFileEquals("BlockingClientStreamingService.kt", code) + } + private fun stubCodeFor( pkg: String, serviceName: String, @@ -207,4 +259,33 @@ class StubTest { .toString() .trim() } + + private fun stubCodeForBlocking( + schemaCode: String, + ): String { + val fileSystem = FakeFileSystem() + val outDirectory = "generated/wire" + val protoPath = "service.proto" + val schema = buildSchema { add(protoPath.toPath(), schemaCode) } + val context = SchemaHandler.Context( + fileSystem = fileSystem, + outDirectory = outDirectory.toPath(), + logger = WireTestLogger(), + sourcePathPaths = setOf(protoPath), + ) + GrpcServerSchemaHandler.Factory().create( + includes = listOf(), + excludes = listOf(), + exclusive = true, + outDirectory = outDirectory, + options = mapOf( + "singleMethodServices" to "false", + "rpcCallStyle" to "blocking", + ), + ) + .handle(schema, context) + return fileSystem.read("generated/wire/test/TestServiceWireGrpc.kt".toPath()) { + readUtf8() + } + } } From 3a1eeac7851ca0d0f1a0f2feeb518b26d9d07f49 Mon Sep 17 00:00:00 2001 From: Sachin Mehta Date: Fri, 18 Oct 2024 11:18:29 +0530 Subject: [PATCH 2/5] BlockingStub type was not extending BlockingStub. --- .../kotlin/grpcserver/BlockingStubGenerator.kt | 1 + .../wire/kotlin/grpcserver/StubGenerator.kt | 17 +++++++++++++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/server-generator/src/main/java/com/squareup/wire/kotlin/grpcserver/BlockingStubGenerator.kt b/server-generator/src/main/java/com/squareup/wire/kotlin/grpcserver/BlockingStubGenerator.kt index 4b39008..a863d01 100644 --- a/server-generator/src/main/java/com/squareup/wire/kotlin/grpcserver/BlockingStubGenerator.kt +++ b/server-generator/src/main/java/com/squareup/wire/kotlin/grpcserver/BlockingStubGenerator.kt @@ -56,6 +56,7 @@ object BlockingStubGenerator { this, service, ClassName("io.grpc.stub", "AbstractStub"), + true, ) } .addBlockingStubRpcCalls(generator, service) diff --git a/server-generator/src/main/java/com/squareup/wire/kotlin/grpcserver/StubGenerator.kt b/server-generator/src/main/java/com/squareup/wire/kotlin/grpcserver/StubGenerator.kt index af33486..cb5f82c 100644 --- a/server-generator/src/main/java/com/squareup/wire/kotlin/grpcserver/StubGenerator.kt +++ b/server-generator/src/main/java/com/squareup/wire/kotlin/grpcserver/StubGenerator.kt @@ -68,6 +68,7 @@ object StubGenerator { this, service, ClassName("io.grpc.kotlin", "AbstractCoroutineStub"), + false, ) addSuspendedStubRpcCalls(generator, this, service, options) } @@ -98,7 +99,13 @@ object StubGenerator { .addType( TypeSpec.classBuilder(stubClassName) .apply { - addAbstractStubConstructor(generator, this, service, ClassName("io.grpc.stub", "AbstractStub")) + addAbstractStubConstructor( + generator, + this, + service, + ClassName("io.grpc.stub", "AbstractStub"), + false + ) addStubRpcCalls(generator, this, service, options) } .build(), @@ -110,12 +117,14 @@ object StubGenerator { builder: TypeSpec.Builder, service: Service, superClass: ClassName, + blockingStub: Boolean, ): TypeSpec.Builder { + val stubType = if (blockingStub) "Blocking" else "" val serviceClassName = generator.classNameFor(service.type) val stubClassName = ClassName( packageName = serviceClassName.packageName, "${serviceClassName.simpleName}WireGrpc", - "${serviceClassName.simpleName}Stub", + "${serviceClassName.simpleName}${stubType}Stub", ) return builder // Really this is a superclass, just want to add secondary constructors. @@ -140,8 +149,8 @@ object StubGenerator { .addModifiers(KModifier.OVERRIDE) .addParameter("channel", ClassName("io.grpc", "Channel")) .addParameter("callOptions", ClassName("io.grpc", "CallOptions")) - .addStatement("return ${service.name}Stub(channel, callOptions)") - .returns(ClassName("", "${service.name}Stub")) + .addStatement("return ${service.name}${stubType}Stub(channel, callOptions)") + .returns(ClassName("", "${service.name}${stubType}Stub")) .build(), ) } From 6e9e218e602690a464ead4e8348f9c3026bd23c7 Mon Sep 17 00:00:00 2001 From: Marius Volkhart Date: Mon, 21 Jul 2025 20:28:47 -0400 Subject: [PATCH 3/5] fixup! Adding tests for blocking stub generation. --- server-generator/src/test/golden/RouteGuideWireGrpc.kt | 6 +++--- server-generator/src/test/golden/nonSingleMethodService.kt | 6 +++--- server-generator/src/test/golden/singleMethodService.kt | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/server-generator/src/test/golden/RouteGuideWireGrpc.kt b/server-generator/src/test/golden/RouteGuideWireGrpc.kt index 1eb0e9b..2640711 100644 --- a/server-generator/src/test/golden/RouteGuideWireGrpc.kt +++ b/server-generator/src/test/golden/RouteGuideWireGrpc.kt @@ -345,13 +345,13 @@ public object RouteGuideWireGrpc { response) } - public class RouteGuideBlockingStub : AbstractStub { + public class RouteGuideBlockingStub : AbstractStub { internal constructor(channel: Channel) : super(channel) internal constructor(channel: Channel, callOptions: CallOptions) : super(channel, callOptions) - override fun build(channel: Channel, callOptions: CallOptions): RouteGuideStub = - RouteGuideStub(channel, callOptions) + override fun build(channel: Channel, callOptions: CallOptions): RouteGuideBlockingStub = + RouteGuideBlockingStub(channel, callOptions) public fun GetFeature(request: Point): Feature = blockingUnaryCall(channel, getGetFeatureMethod(), callOptions, request) diff --git a/server-generator/src/test/golden/nonSingleMethodService.kt b/server-generator/src/test/golden/nonSingleMethodService.kt index 7246a7c..674d1bf 100644 --- a/server-generator/src/test/golden/nonSingleMethodService.kt +++ b/server-generator/src/test/golden/nonSingleMethodService.kt @@ -205,13 +205,13 @@ public object FooServiceWireGrpc { } } - public class FooServiceBlockingStub : AbstractStub { + public class FooServiceBlockingStub : AbstractStub { internal constructor(channel: Channel) : super(channel) internal constructor(channel: Channel, callOptions: CallOptions) : super(channel, callOptions) - override fun build(channel: Channel, callOptions: CallOptions): FooServiceStub = - FooServiceStub(channel, callOptions) + override fun build(channel: Channel, callOptions: CallOptions): FooServiceBlockingStub = + FooServiceBlockingStub(channel, callOptions) public fun Call1(request: Request): Response = blockingUnaryCall(channel, getCall1Method(), callOptions, request) diff --git a/server-generator/src/test/golden/singleMethodService.kt b/server-generator/src/test/golden/singleMethodService.kt index 401b121..eb0ef0e 100644 --- a/server-generator/src/test/golden/singleMethodService.kt +++ b/server-generator/src/test/golden/singleMethodService.kt @@ -206,13 +206,13 @@ public object FooServiceWireGrpc { } } - public class FooServiceBlockingStub : AbstractStub { + public class FooServiceBlockingStub : AbstractStub { internal constructor(channel: Channel) : super(channel) internal constructor(channel: Channel, callOptions: CallOptions) : super(channel, callOptions) - override fun build(channel: Channel, callOptions: CallOptions): FooServiceStub = - FooServiceStub(channel, callOptions) + override fun build(channel: Channel, callOptions: CallOptions): FooServiceBlockingStub = + FooServiceBlockingStub(channel, callOptions) public fun Call1(request: Request): Response = blockingUnaryCall(channel, getCall1Method(), callOptions, request) From cbf1ab10e72f69ef31d3cfdab09df2fb987ad7db Mon Sep 17 00:00:00 2001 From: Marius Volkhart Date: Tue, 22 Jul 2025 08:24:19 -0400 Subject: [PATCH 4/5] fixup! Adding tests for blocking stub generation. --- .../test/java/com/squareup/wire/kotlin/grpcserver/StubTest.kt | 1 - 1 file changed, 1 deletion(-) diff --git a/server-generator/src/test/java/com/squareup/wire/kotlin/grpcserver/StubTest.kt b/server-generator/src/test/java/com/squareup/wire/kotlin/grpcserver/StubTest.kt index cf77d10..19ce96b 100644 --- a/server-generator/src/test/java/com/squareup/wire/kotlin/grpcserver/StubTest.kt +++ b/server-generator/src/test/java/com/squareup/wire/kotlin/grpcserver/StubTest.kt @@ -188,7 +188,6 @@ class StubTest { ) } - @Test fun `generates stubs for blocking bidi streaming rpc`() { val code = stubCodeForBlocking( From f81099357f1412011a2371686c05432d18780219 Mon Sep 17 00:00:00 2001 From: Marius Volkhart Date: Tue, 22 Jul 2025 08:24:36 -0400 Subject: [PATCH 5/5] fixup! BlockingStub type was not extending BlockingStub. --- .../java/com/squareup/wire/kotlin/grpcserver/StubGenerator.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server-generator/src/main/java/com/squareup/wire/kotlin/grpcserver/StubGenerator.kt b/server-generator/src/main/java/com/squareup/wire/kotlin/grpcserver/StubGenerator.kt index cb5f82c..c58f465 100644 --- a/server-generator/src/main/java/com/squareup/wire/kotlin/grpcserver/StubGenerator.kt +++ b/server-generator/src/main/java/com/squareup/wire/kotlin/grpcserver/StubGenerator.kt @@ -104,7 +104,7 @@ object StubGenerator { this, service, ClassName("io.grpc.stub", "AbstractStub"), - false + false, ) addStubRpcCalls(generator, this, service, options) }