diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java index ab4eab3048..366277f711 100644 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java +++ b/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java @@ -23,7 +23,9 @@ import com.google.protobuf.CodedInputStream; import com.google.protobuf.CodedOutputStream; import com.google.protobuf.WireFormat; +import io.grpc.Detachable; import io.grpc.Drainable; +import io.grpc.HasByteBuffer; import io.grpc.MethodDescriptor.Marshaller; import io.grpc.protobuf.ProtoUtils; import io.netty.buffer.ByteBuf; @@ -41,11 +43,12 @@ import java.util.Collections; import java.util.List; import org.apache.arrow.flight.grpc.AddWritableBuffer; -import org.apache.arrow.flight.grpc.GetReadableBuffer; import org.apache.arrow.flight.impl.Flight.FlightData; import org.apache.arrow.flight.impl.Flight.FlightDescriptor; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.ForeignAllocation; +import org.apache.arrow.memory.util.MemoryUtil; import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; @@ -55,10 +58,14 @@ import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.MetadataVersion; import org.apache.arrow.vector.types.pojo.Schema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** The in-memory representation of FlightData used to manage a stream of Arrow messages. */ class ArrowMessage implements AutoCloseable { + private static final Logger LOG = LoggerFactory.getLogger(ArrowMessage.class); + // If true, deserialize Arrow data by giving Arrow a reference to the underlying gRPC buffer // instead of copying the data. Defaults to true. 
public static final boolean ENABLE_ZERO_COPY_READ; @@ -312,8 +319,7 @@ private static ArrowMessage frame(BufferAllocator allocator, final InputStream s case APP_METADATA_TAG: { int size = readRawVarint32(stream); - appMetadata = allocator.buffer(size); - GetReadableBuffer.readIntoBuffer(stream, appMetadata, size, ENABLE_ZERO_COPY_READ); + appMetadata = readBuffer(allocator, stream, size); break; } case BODY_TAG: @@ -323,8 +329,7 @@ private static ArrowMessage frame(BufferAllocator allocator, final InputStream s body = null; } int size = readRawVarint32(stream); - body = allocator.buffer(size); - GetReadableBuffer.readIntoBuffer(stream, body, size, ENABLE_ZERO_COPY_READ); + body = readBuffer(allocator, stream, size); break; default: @@ -377,6 +382,114 @@ private static int readRawVarint32(int firstByte, InputStream is) throws IOExcep return CodedInputStream.readRawVarint32(firstByte, is); } + /** + * Reads data from the stream into an ArrowBuf, without copying data when possible. + * + *

First attempts to transfer ownership of the gRPC buffer to Arrow via {@link + * #wrapGrpcBuffer}. This avoids any memory copy when the gRPC transport provides a direct + * ByteBuffer (e.g., Netty). + * + *

If not possible (e.g., heap buffer, fragmented data, or unsupported transport), falls back
+   * to allocating a new buffer and copying data into it. The fallback buffer is released again
+   * if the copy fails, so a truncated stream does not leak allocator memory.
+   *
+   * @param allocator The allocator to use for buffer allocation
+   * @param stream The input stream to read from
+   * @param size The number of bytes to read
+   * @return An ArrowBuf containing the data; the caller owns it and must close it
+   * @throws IOException if there is an error reading from the stream
+   */
+  private static ArrowBuf readBuffer(BufferAllocator allocator, InputStream stream, int size)
+      throws IOException {
+    if (ENABLE_ZERO_COPY_READ) {
+      ArrowBuf zeroCopyBuf = wrapGrpcBuffer(stream, allocator, size);
+      if (zeroCopyBuf != null) {
+        return zeroCopyBuf;
+      }
+    }
+
+    // Fall back to allocating and copying
+    ArrowBuf buf = allocator.buffer(size);
+    try {
+      byte[] heapBytes = new byte[size];
+      ByteStreams.readFully(stream, heapBytes);
+      buf.writeBytes(heapBytes);
+      buf.writerIndex(size);
+      return buf;
+    } catch (Throwable t) {
+      // The buffer has not been handed to the caller yet, so nobody else can release it
+      buf.close();
+      throw t;
+    }
+  }
+
+  /**
+   * Attempts to wrap gRPC's buffer as an ArrowBuf without copying.
+   *
+   * 

This method takes ownership of gRPC's underlying buffer via {@link Detachable#detach()} and
+   * wraps it as an ArrowBuf using {@link BufferAllocator#wrapForeignAllocation}. The gRPC buffer
+   * will be released when the ArrowBuf is closed.
+   *
+   * Detaching hands the entire remaining content of the stream to the detached stream and
+   * leaves the original stream exhausted. Zero-copy is therefore only attempted when the
+   * requested {@code size} bytes are exactly what is left in the stream; if more data follows,
+   * {@code null} is returned so that the caller can copy and keep reading the following fields.
+   *
+   * @param stream The gRPC-provided InputStream
+   * @param allocator The allocator to use for wrapping the foreign allocation
+   * @param size The number of bytes to wrap
+   * @return An ArrowBuf wrapping gRPC's buffer, or {@code null} if zero-copy is not possible; when
+   *     {@code null} is returned the stream has not been consumed or detached
+   * @throws IllegalStateException if the stream was detached but the detached stream does not
+   *     expose the expected direct buffer (the detached stream is closed before throwing)
+   */
+  static ArrowBuf wrapGrpcBuffer(
+      final InputStream stream, final BufferAllocator allocator, final int size) {
+
+    if (size <= 0) {
+      // Nothing to wrap; do not detach the stream for an empty (or invalid) request
+      return null;
+    }
+
+    if (!(stream instanceof Detachable) || !(stream instanceof HasByteBuffer)) {
+      return null;
+    }
+
+    HasByteBuffer hasByteBuffer = (HasByteBuffer) stream;
+    if (!hasByteBuffer.byteBufferSupported()) {
+      return null;
+    }
+
+    ByteBuffer peekBuffer = hasByteBuffer.getByteBuffer();
+    if (peekBuffer == null || !peekBuffer.isDirect()) {
+      return null;
+    }
+    if (peekBuffer.remaining() < size) {
+      // Data is fragmented across multiple buffers; zero-copy not possible
+      return null;
+    }
+
+    try {
+      // NOTE(review): relies on available() reporting the exact number of unread bytes, which is
+      // expected for gRPC-provided streams - confirm for every supported transport.
+      if (stream.available() != size) {
+        // More data follows the requested bytes. detach() would move all of it into the detached
+        // stream and leave this one exhausted, silently dropping the remaining fields.
+        return null;
+      }
+    } catch (IOException e) {
+      LOG.debug("Unable to determine remaining bytes in gRPC stream; not using zero-copy", e);
+      return null;
+    }
+
+    // Take ownership. From here on the original stream is exhausted, so returning null (which
+    // would make the caller copy from it) is no longer an option: either succeed or fail loudly.
+    final InputStream detachedStream = ((Detachable) stream).detach();
+    try {
+      if (!(detachedStream instanceof HasByteBuffer)) {
+        throw new IllegalStateException("Detached gRPC stream does not expose a ByteBuffer");
+      }
+
+      // Get buffer from detached stream
+      ByteBuffer detachedByteBuffer = ((HasByteBuffer) detachedStream).getByteBuffer();
+      if (detachedByteBuffer == null
+          || !detachedByteBuffer.isDirect()
+          || detachedByteBuffer.remaining() < size) {
+        throw new IllegalStateException(
+            "Detached gRPC stream does not provide a direct buffer of " + size + " bytes");
+      }
+
+      // Calculate memory address accounting for buffer position
+      long baseAddress = MemoryUtil.getByteBufferAddress(detachedByteBuffer);
+      long dataAddress = baseAddress + detachedByteBuffer.position();
+
+      // Create ForeignAllocation with proper cleanup
+      ForeignAllocation foreignAllocation =
+          new ForeignAllocation(size, dataAddress) {
+            @Override
+            protected void release0() {
+              closeQuietly(detachedStream);
+            }
+          };
+
+      return allocator.wrapForeignAllocation(foreignAllocation);
+    } catch (Throwable t) {
+      // Any failure after detach() must release the detached stream, otherwise the gRPC buffer
+      // it owns is never returned
+      closeQuietly(detachedStream);
+      throw t;
+    }
+  }
+
+  private static void 
closeQuietly(InputStream stream) { + if (stream != null) { + try { + stream.close(); + } catch (IOException e) { + LOG.debug("Error closing detached gRPC stream", e); + } + } + } + /** * Convert the ArrowMessage to an InputStream. * diff --git a/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java b/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java deleted file mode 100644 index 45c32a86c6..0000000000 --- a/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.arrow.flight.grpc; - -import com.google.common.base.Throwables; -import com.google.common.io.ByteStreams; -import io.grpc.internal.ReadableBuffer; -import java.io.IOException; -import java.io.InputStream; -import java.lang.reflect.Field; -import org.apache.arrow.memory.ArrowBuf; - -/** - * Enable access to ReadableBuffer directly to copy data from a BufferInputStream into a target - * ByteBuffer/ByteBuf. - * - *

This could be solved by BufferInputStream exposing Drainable. - */ -public class GetReadableBuffer { - - private static final Field READABLE_BUFFER; - private static final Class BUFFER_INPUT_STREAM; - - static { - Field tmpField = null; - Class tmpClazz = null; - try { - Class clazz = Class.forName("io.grpc.internal.ReadableBuffers$BufferInputStream"); - - Field f = clazz.getDeclaredField("buffer"); - f.setAccessible(true); - // don't set until we've gotten past all exception cases. - tmpField = f; - tmpClazz = clazz; - } catch (Exception e) { - new RuntimeException("Failed to initialize GetReadableBuffer, falling back to slow path", e) - .printStackTrace(); - } - READABLE_BUFFER = tmpField; - BUFFER_INPUT_STREAM = tmpClazz; - } - - /** - * Extracts the ReadableBuffer for the given input stream. - * - * @param is Must be an instance of io.grpc.internal.ReadableBuffers$BufferInputStream or null - * will be returned. - */ - public static ReadableBuffer getReadableBuffer(InputStream is) { - - if (BUFFER_INPUT_STREAM == null || !is.getClass().equals(BUFFER_INPUT_STREAM)) { - return null; - } - - try { - return (ReadableBuffer) READABLE_BUFFER.get(is); - } catch (Exception ex) { - throw Throwables.propagate(ex); - } - } - - /** - * Helper method to read a gRPC-provided InputStream into an ArrowBuf. - * - * @param stream The stream to read from. Should be an instance of {@link #BUFFER_INPUT_STREAM}. - * @param buf The buffer to read into. - * @param size The number of bytes to read. - * @param fastPath Whether to enable the fast path (i.e. detect whether the stream is a {@link - * #BUFFER_INPUT_STREAM}). - * @throws IOException if there is an error reading form the stream - */ - public static void readIntoBuffer( - final InputStream stream, final ArrowBuf buf, final int size, final boolean fastPath) - throws IOException { - ReadableBuffer readableBuffer = fastPath ? 
getReadableBuffer(stream) : null; - if (readableBuffer != null) { - readableBuffer.readBytes(buf.nioBuffer(0, size)); - } else { - byte[] heapBytes = new byte[size]; - ByteStreams.readFully(stream, heapBytes); - buf.writeBytes(heapBytes); - } - buf.writerIndex(size); - } -} diff --git a/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java new file mode 100644 index 0000000000..099b1cd3e5 --- /dev/null +++ b/flight/flight-core/src/test/java/org/apache/arrow/flight/TestArrowMessageZeroCopy.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.arrow.flight;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertInstanceOf;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import io.grpc.Detachable;
+import io.grpc.HasByteBuffer;
+import io.grpc.internal.ReadableBuffer;
+import io.grpc.internal.ReadableBuffers;
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.util.Random;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Tests for {@link ArrowMessage#wrapGrpcBuffer}, covering both the cases where a gRPC stream can
+ * be wrapped as an ArrowBuf and the cases where {@code null} must be returned.
+ */
+public class TestArrowMessageZeroCopy {
+
+  private BufferAllocator allocator;
+
+  @BeforeEach
+  public void setUp() {
+    allocator = new RootAllocator(Long.MAX_VALUE);
+  }
+
+  @AfterEach
+  public void tearDown() {
+    allocator.close();
+  }
+
+  /** Opens a gRPC stream over a direct ByteBuffer holding a copy of {@code data}. */
+  private static InputStream createGrpcStreamWithDirectBuffer(byte[] data) {
+    ByteBuffer directBuffer = ByteBuffer.allocateDirect(data.length);
+    directBuffer.put(data);
+    directBuffer.flip();
+    ReadableBuffer readableBuffer = ReadableBuffers.wrap(directBuffer);
+    return ReadableBuffers.openStream(readableBuffer, true);
+  }
+
+  /** Copies the first {@code length} bytes of {@code buf} into a new array. */
+  private static byte[] readBytes(ArrowBuf buf, int length) {
+    byte[] readData = new byte[length];
+    buf.getBytes(0, readData);
+    return readData;
+  }
+
+  @Test
+  public void testWrapGrpcBufferReturnsNullForRegularInputStream() throws IOException {
+    byte[] testData = new byte[] {1, 2, 3, 4, 5};
+    InputStream stream = new ByteArrayInputStream(testData);
+
+    // ByteArrayInputStream doesn't implement Detachable or HasByteBuffer
+    ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length);
+    assertNull(result, "Should return null for streams not implementing required interfaces");
+  }
+
+  @Test
+  public void testWrapGrpcBufferSucceedsForRealGrpcDirectBuffer() throws IOException {
+    byte[] testData = new byte[] {11, 22, 33, 44, 55};
+    InputStream stream = createGrpcStreamWithDirectBuffer(testData);
+
+    assertInstanceOf(Detachable.class, stream, "Real gRPC stream should implement Detachable");
+    assertInstanceOf(
+        HasByteBuffer.class, stream, "Real gRPC stream should implement HasByteBuffer");
+    assertTrue(
+        ((HasByteBuffer) stream).byteBufferSupported(),
+        "Direct buffer stream should support ByteBuffer");
+    assertTrue(
+        ((HasByteBuffer) stream).getByteBuffer().isDirect(),
+        "Should have direct ByteBuffer backing");
+
+    try (ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length)) {
+      assertNotNull(result, "Should succeed for real gRPC stream with direct buffer");
+      assertEquals(testData.length, result.capacity());
+
+      // Check received data is the same
+      assertArrayEquals(testData, readBytes(result, testData.length));
+    }
+  }
+
+  @Test
+  public void testWrapGrpcBufferReturnsNullForRealGrpcHeapByteBuffer() throws IOException {
+    byte[] testData = new byte[] {1, 2, 3, 4, 5};
+    ByteBuffer heapBuffer = ByteBuffer.wrap(testData);
+    ReadableBuffer readableBuffer = ReadableBuffers.wrap(heapBuffer);
+
+    InputStream stream = ReadableBuffers.openStream(readableBuffer, true);
+
+    assertInstanceOf(Detachable.class, stream, "Real gRPC stream should implement Detachable");
+    assertInstanceOf(
+        HasByteBuffer.class, stream, "Real gRPC stream should implement HasByteBuffer");
+    assertTrue(
+        ((HasByteBuffer) stream).byteBufferSupported(),
+        "Heap ByteBuffer stream should support ByteBuffer");
+    assertFalse(
+        ((HasByteBuffer) stream).getByteBuffer().isDirect(), "Should have heap ByteBuffer backing");
+
+    // Zero-copy should return null for heap buffer (not direct)
+    ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length);
+    assertNull(result, "Should return null for real gRPC stream with heap buffer");
+  }
+
+  @Test
+  public void testWrapGrpcBufferReturnsNullForRealGrpcByteArrayStream() throws IOException {
+    byte[] testData = new byte[] {1, 2, 3, 4, 5};
+    ReadableBuffer readableBuffer = ReadableBuffers.wrap(testData);
+    InputStream stream = ReadableBuffers.openStream(readableBuffer, true);
+
+    // Verify the stream has the expected gRPC interfaces
+    assertInstanceOf(Detachable.class, stream, "Real gRPC stream should implement Detachable");
+    assertInstanceOf(
+        HasByteBuffer.class, stream, "Real gRPC stream should implement HasByteBuffer");
+    // Byte array backed streams don't support ByteBuffer access
+    assertFalse(
+        ((HasByteBuffer) stream).byteBufferSupported(),
+        "Byte array stream should not support ByteBuffer");
+
+    // Zero-copy should return null when byteBufferSupported() is false
+    ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length);
+    assertNull(result, "Should return null for real gRPC stream backed by byte array");
+  }
+
+  @Test
+  public void testWrapGrpcBufferMemoryAccountingWithRealGrpcStream() throws IOException {
+    byte[] testData = new byte[1024];
+    new Random(42).nextBytes(testData);
+    InputStream stream = createGrpcStreamWithDirectBuffer(testData);
+
+    long memoryBefore = allocator.getAllocatedMemory();
+    assertEquals(0, memoryBefore);
+
+    ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length);
+    assertNotNull(result, "Should succeed for real gRPC stream with direct buffer");
+
+    // Release the buffer even when an assertion fails; otherwise tearDown() reports an allocator
+    // leak on top of the real failure
+    try {
+      long memoryDuring = allocator.getAllocatedMemory();
+      assertEquals(testData.length, memoryDuring);
+
+      assertArrayEquals(testData, readBytes(result, testData.length));
+    } finally {
+      result.close();
+    }
+
+    long memoryAfter = allocator.getAllocatedMemory();
+    assertEquals(0, memoryAfter);
+  }
+
+  @Test
+  public void testWrapGrpcBufferReturnsNullForInsufficientDataWithRealGrpcStream()
+      throws IOException {
+    byte[] testData = new byte[] {1, 2, 3};
+    InputStream stream = createGrpcStreamWithDirectBuffer(testData);
+
+    // Request more data than available
+    ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, 10);
+    assertNull(result, "Should return null when buffer has insufficient data");
+  }
+
+  @Test
+  public void testWrapGrpcBufferLargeDataWithRealGrpcStream() throws IOException {
+    // Test with larger data (64KB)
+    byte[] testData = new byte[64 * 1024];
+    new Random(42).nextBytes(testData);
+    InputStream stream = createGrpcStreamWithDirectBuffer(testData);
+
+    try (ArrowBuf result = ArrowMessage.wrapGrpcBuffer(stream, allocator, testData.length)) {
+      assertNotNull(result, "Should succeed for large data with real gRPC stream");
+      assertEquals(testData.length, result.capacity());
+
+      // Verify data integrity
+      assertArrayEquals(testData, readBytes(result, testData.length));
+    }
+  }
+}