Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 3.x]
### Added
- Add circuit breaker support for gRPC transport to prevent out-of-memory errors ([#20203](https://github.com/opensearch-project/OpenSearch/pull/20203))
- Allow setting index.creation_date on index creation and restore for plugin compatibility and migrations ([#19931](https://github.com/opensearch-project/OpenSearch/pull/19931))
- Add support for a ForkJoinPool type ([#19008](https://github.com/opensearch-project/OpenSearch/pull/19008))
- Add separate shard limit validation for local and remote indices ([#19532](https://github.com/opensearch-project/OpenSearch/pull/19532))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ public Map<String, Supplier<AuxTransport>> getAuxTransports(

return Collections.singletonMap(GRPC_TRANSPORT_SETTING_KEY, () -> {
List<BindableService> grpcServices = new ArrayList<>(
List.of(new DocumentServiceImpl(client), new SearchServiceImpl(client, queryUtils))
List.of(new DocumentServiceImpl(client, circuitBreakerService), new SearchServiceImpl(client, queryUtils))
);
for (GrpcServiceFactory serviceFac : servicesFactory) {
List<BindableService> pluginServices = serviceFac.initClient(client)
Expand Down Expand Up @@ -234,7 +234,7 @@ public Map<String, Supplier<AuxTransport>> getSecureAuxTransports(
}
return Collections.singletonMap(GRPC_SECURE_TRANSPORT_SETTING_KEY, () -> {
List<BindableService> grpcServices = new ArrayList<>(
List.of(new DocumentServiceImpl(client), new SearchServiceImpl(client, queryUtils))
List.of(new DocumentServiceImpl(client, circuitBreakerService), new SearchServiceImpl(client, queryUtils))
);
for (GrpcServiceFactory serviceFac : servicesFactory) {
List<BindableService> pluginServices = serviceFac.initClient(client)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.indices.breaker.CircuitBreakerService;
import org.opensearch.protobufs.services.DocumentServiceGrpc;
import org.opensearch.transport.client.Client;
import org.opensearch.transport.grpc.listeners.BulkRequestActionListener;
Expand All @@ -19,35 +21,76 @@
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;

import java.util.concurrent.atomic.AtomicBoolean;

/**
* Implementation of the gRPC Document Service.
*/
public class DocumentServiceImpl extends DocumentServiceGrpc.DocumentServiceImplBase {
private static final Logger logger = LogManager.getLogger(DocumentServiceImpl.class);
private final Client client;
private final CircuitBreakerService circuitBreakerService;

/**
* Creates a new DocumentServiceImpl.
*
* @param client Client for executing actions on the local node
* @param circuitBreakerService Circuit breaker service for memory protection
*/
public DocumentServiceImpl(Client client) {
public DocumentServiceImpl(Client client, CircuitBreakerService circuitBreakerService) {
this.client = client;
this.circuitBreakerService = circuitBreakerService;
}

/**
* Processes a bulk request.
* Checks circuit breakers before processing, similar to how REST API handles requests.
*
* @param request The bulk request to process
* @param responseObserver The observer to send the response back to the client
*/
@Override
public void bulk(org.opensearch.protobufs.BulkRequest request, StreamObserver<org.opensearch.protobufs.BulkResponse> responseObserver) {
final int contentLength = request.getSerializedSize();
CircuitBreaker inFlightRequestsBreaker = circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS);
final AtomicBoolean closed = new AtomicBoolean(false);

try {
inFlightRequestsBreaker.addEstimateBytesAndMaybeBreak(contentLength, "<grpc_bulk_request>");
Comment on lines +54 to +59
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Critical: Bytes released even when breaker trips before adding them.

If addEstimateBytesAndMaybeBreak() throws a CircuitBreakingException, the bytes were not added to the breaker. However, the catch block at lines 91-93 unconditionally calls addWithoutBreaking(-contentLength), which will subtract bytes that were never reserved. This corrupts the circuit breaker accounting and can lead to negative byte tracking.

Track whether bytes were successfully added before attempting release:

     final int contentLength = request.getSerializedSize();
     CircuitBreaker inFlightRequestsBreaker = circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS);
     final AtomicBoolean closed = new AtomicBoolean(false);
+    boolean bytesReserved = false;

     try {
         inFlightRequestsBreaker.addEstimateBytesAndMaybeBreak(contentLength, "<grpc_bulk_request>");
+        bytesReserved = true;

         org.opensearch.action.bulk.BulkRequest bulkRequest = BulkRequestProtoUtils.prepareRequest(request);
         // ... (listener code unchanged)

         client.bulk(bulkRequest, wrappedListener);
     } catch (RuntimeException e) {
-        if (closed.compareAndSet(false, true)) {
+        if (bytesReserved && closed.compareAndSet(false, true)) {
             inFlightRequestsBreaker.addWithoutBreaking(-contentLength);
         }
         logger.debug("DocumentServiceImpl failed: {} - {}", e.getClass().getSimpleName(), e.getMessage());

Also applies to: 90-93

🤖 Prompt for AI Agents
In
modules/transport-grpc/src/main/java/org/opensearch/transport/grpc/services/DocumentServiceImpl.java
around lines 54-59 (and also address lines 90-93), the code currently calls
inFlightRequestsBreaker.addEstimateBytesAndMaybeBreak(contentLength, ...) and
later unconditionally calls addWithoutBreaking(-contentLength) even if the
addEstimateBytesAndMaybeBreak threw, which can decrement bytes that were never
added; fix by introducing a boolean flag (e.g., bytesReserved) set to true only
after addEstimateBytesAndMaybeBreak returns successfully, and only call
addWithoutBreaking(-contentLength) in the cleanup/close path when bytesReserved
is true (ensure the same pattern is applied to the code at lines 90-93) so you
never release bytes that were not reserved.


org.opensearch.action.bulk.BulkRequest bulkRequest = BulkRequestProtoUtils.prepareRequest(request);
BulkRequestActionListener listener = new BulkRequestActionListener(responseObserver);
client.bulk(bulkRequest, listener);

BulkRequestActionListener baseListener = new BulkRequestActionListener(responseObserver);
org.opensearch.core.action.ActionListener<org.opensearch.action.bulk.BulkResponse> wrappedListener =
new org.opensearch.core.action.ActionListener<org.opensearch.action.bulk.BulkResponse>() {
@Override
public void onResponse(org.opensearch.action.bulk.BulkResponse response) {
try {
baseListener.onResponse(response);
} finally {
if (closed.compareAndSet(false, true)) {
inFlightRequestsBreaker.addWithoutBreaking(-contentLength);
}
}
}

@Override
public void onFailure(Exception e) {
try {
baseListener.onFailure(e);
} finally {
if (closed.compareAndSet(false, true)) {
inFlightRequestsBreaker.addWithoutBreaking(-contentLength);
}
}
}
};

client.bulk(bulkRequest, wrappedListener);
} catch (RuntimeException e) {
if (closed.compareAndSet(false, true)) {
inFlightRequestsBreaker.addWithoutBreaking(-contentLength);
}
logger.debug("DocumentServiceImpl failed: {} - {}", e.getClass().getSimpleName(), e.getMessage());
StatusRuntimeException grpcError = GrpcErrorHandler.convertToGrpcError(e);
responseObserver.onError(grpcError);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import io.grpc.BindableService;
import io.grpc.Metadata;
Expand Down Expand Up @@ -629,7 +630,9 @@ private ExtensiblePlugin.ExtensionLoader createMockLoader(List<Integer> orders)
when(mockProvider.getOrderedGrpcInterceptors(Mockito.any())).thenReturn(new ArrayList<>());
when(mockLoader.loadExtensions(GrpcInterceptorProvider.class)).thenReturn(List.of(mockProvider));
} else {
List<OrderedGrpcInterceptor> interceptors = orders.stream().map(order -> createMockInterceptor(order)).toList();
List<OrderedGrpcInterceptor> interceptors = orders.stream()
.map(order -> createMockInterceptor(order))
.collect(Collectors.toList());

GrpcInterceptorProvider mockProvider = Mockito.mock(GrpcInterceptorProvider.class);
when(mockProvider.getOrderedGrpcInterceptors(Mockito.any())).thenReturn(interceptors);
Expand All @@ -647,11 +650,11 @@ private ExtensiblePlugin.ExtensionLoader createMockLoaderWithMultipleProviders(L
when(mockLoader.loadExtensions(QueryBuilderProtoConverter.class)).thenReturn(null);

List<GrpcInterceptorProvider> providers = providerOrders.stream().map(orders -> {
List<OrderedGrpcInterceptor> interceptors = orders.stream().map(this::createMockInterceptor).toList();
List<OrderedGrpcInterceptor> interceptors = orders.stream().map(this::createMockInterceptor).collect(Collectors.toList());
GrpcInterceptorProvider provider = Mockito.mock(GrpcInterceptorProvider.class);
when(provider.getOrderedGrpcInterceptors(Mockito.any())).thenReturn(interceptors);
return provider;
}).toList();
}).collect(Collectors.toList());

when(mockLoader.loadExtensions(GrpcInterceptorProvider.class)).thenReturn(providers);
return mockLoader;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
package org.opensearch.transport.grpc.services.document;

import com.google.protobuf.ByteString;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.core.indices.breaker.CircuitBreakerService;
import org.opensearch.protobufs.BulkRequest;
import org.opensearch.protobufs.BulkRequestBody;
import org.opensearch.protobufs.IndexOperation;
Expand All @@ -20,12 +25,20 @@
import java.io.IOException;

import io.grpc.stub.StreamObserver;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class DocumentServiceImplTests extends OpenSearchTestCase {

Expand All @@ -34,13 +47,20 @@ public class DocumentServiceImplTests extends OpenSearchTestCase {
@Mock
private NodeClient client;

@Mock
private CircuitBreakerService circuitBreakerService;

@Mock
private CircuitBreaker circuitBreaker;

@Mock
private StreamObserver<org.opensearch.protobufs.BulkResponse> responseObserver;

@Before
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);
service = new DocumentServiceImpl(client);
when(circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS)).thenReturn(circuitBreaker);
service = new DocumentServiceImpl(client, circuitBreakerService);
}

public void testBulkSuccess() throws IOException {
Expand Down Expand Up @@ -68,6 +88,120 @@ public void testBulkError() throws IOException {
verify(responseObserver).onError(any(RuntimeException.class));
}

/**
 * Verifies that a bulk call reserves the request's serialized size on the in-flight-requests
 * breaker, using the gRPC bulk label, and then forwards the request to the client.
 */
public void testCircuitBreakerCheckedBeforeProcessing() throws IOException {
    BulkRequest request = createTestBulkRequest();
    // The service sizes its reservation from getSerializedSize(), so pin that exact value
    // instead of accepting any long.
    final long expectedBytes = request.getSerializedSize();

    service.bulk(request, responseObserver);

    // The breaker must be charged with exactly the request size
    verify(circuitBreaker).addEstimateBytesAndMaybeBreak(eq(expectedBytes), eq("<grpc_bulk_request>"));

    // The request must still reach the client
    verify(client).bulk(any(org.opensearch.action.bulk.BulkRequest.class), any());
}

/**
 * When the in-flight-requests breaker rejects the reservation, the bulk call must not reach
 * the client and the caller must be notified through {@code onError}.
 */
public void testCircuitBreakerTripsAndRejectsRequest() throws IOException {
    BulkRequest bulkProto = createTestBulkRequest();

    // Arrange: every reservation attempt on the mocked breaker is rejected
    doThrow(new CircuitBreakingException("Data too large", 100L, 50 * 1024 * 1024L, CircuitBreaker.Durability.TRANSIENT)).when(
        circuitBreaker
    ).addEstimateBytesAndMaybeBreak(anyLong(), anyString());

    // Act
    service.bulk(bulkProto, responseObserver);

    // Assert: the reservation was attempted under the gRPC bulk label ...
    verify(circuitBreaker).addEstimateBytesAndMaybeBreak(anyLong(), eq("<grpc_bulk_request>"));

    // ... the request never reached the client ...
    verify(client, never()).bulk(any(org.opensearch.action.bulk.BulkRequest.class), any());

    // ... a release was issued on the breaker ...
    // NOTE(review): this expects a release even though the reservation call itself threw;
    // confirm that releasing after a failed reservation is the intended accounting.
    verify(circuitBreaker).addWithoutBreaking(anyLong());

    // ... and the failure was reported to the gRPC caller
    verify(responseObserver).onError(any());
}

/**
 * Verifies that the bytes reserved for a bulk request are handed back to the breaker once the
 * bulk action completes successfully, and not before.
 */
public void testCircuitBreakerBytesReleasedOnSuccess() throws IOException {
    BulkRequest request = createTestBulkRequest();
    // The service reserves getSerializedSize() bytes and releases the negated amount
    final long reservedBytes = request.getSerializedSize();

    // Capture the listener handed to the client so completion can be simulated
    @SuppressWarnings("unchecked")
    ArgumentCaptor<ActionListener<BulkResponse>> listenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

    service.bulk(request, responseObserver);
    verify(client).bulk(any(org.opensearch.action.bulk.BulkRequest.class), listenerCaptor.capture());

    // While the request is still in flight nothing may be released
    verify(circuitBreaker, never()).addWithoutBreaking(anyLong());

    // Simulate a successful response
    BulkResponse mockResponse = mock(BulkResponse.class);
    when(mockResponse.hasFailures()).thenReturn(false);
    listenerCaptor.getValue().onResponse(mockResponse);

    // The release must be exactly the negated reservation; anyLong() would also accept a positive value
    verify(circuitBreaker).addWithoutBreaking(-reservedBytes);
}

/**
 * Verifies that the bytes reserved for a bulk request are handed back to the breaker once the
 * bulk action fails, and not before.
 */
public void testCircuitBreakerBytesReleasedOnFailure() throws IOException {
    BulkRequest request = createTestBulkRequest();
    // The service reserves getSerializedSize() bytes and releases the negated amount
    final long reservedBytes = request.getSerializedSize();

    // Capture the listener handed to the client so the failure can be simulated
    @SuppressWarnings("unchecked")
    ArgumentCaptor<ActionListener<BulkResponse>> listenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

    service.bulk(request, responseObserver);
    verify(client).bulk(any(org.opensearch.action.bulk.BulkRequest.class), listenerCaptor.capture());

    // While the request is still in flight nothing may be released
    verify(circuitBreaker, never()).addWithoutBreaking(anyLong());

    // Simulate failure
    Exception testException = new RuntimeException("Bulk operation failed");
    listenerCaptor.getValue().onFailure(testException);

    // The release must be exactly the negated reservation; anyLong() would also accept a positive value
    verify(circuitBreaker).addWithoutBreaking(-reservedBytes);
}

/**
 * Verifies that the reserved bytes are released a single time even if the listener passed to
 * the client is completed more than once.
 */
public void testCircuitBreakerBytesReleasedExactlyOnce() throws IOException {
    BulkRequest request = createTestBulkRequest();
    final long reservedBytes = request.getSerializedSize();

    // Capture the listener handed to the client
    @SuppressWarnings("unchecked")
    ArgumentCaptor<ActionListener<BulkResponse>> listenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

    service.bulk(request, responseObserver);

    // The reservation happened under the gRPC bulk label
    verify(circuitBreaker).addEstimateBytesAndMaybeBreak(anyLong(), eq("<grpc_bulk_request>"));

    verify(client).bulk(any(org.opensearch.action.bulk.BulkRequest.class), listenerCaptor.capture());
    ActionListener<BulkResponse> listener = listenerCaptor.getValue();

    // First completion: successful response
    BulkResponse mockResponse = mock(BulkResponse.class);
    when(mockResponse.hasFailures()).thenReturn(false);
    listener.onResponse(mockResponse);

    // Second completion on the same listener: a single callback cannot expose a double release,
    // so deliver a late failure as well to exercise the guard.
    listener.onFailure(new RuntimeException("duplicate completion"));

    // Exactly one release in total, and it carries the negated reservation
    verify(circuitBreaker, times(1)).addWithoutBreaking(anyLong());
    verify(circuitBreaker).addWithoutBreaking(-reservedBytes);
}

private BulkRequest createTestBulkRequest() {
IndexOperation indexOp = IndexOperation.newBuilder().setXIndex("test-index").setXId("test-id").build();

Expand Down
Loading