Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
internal/mocks/*.go linguist-generated=true
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI so that when changing these, the "View PR" view on Github says that the files are automatically generated and people can just mark "viewed" on them

*.pb.go linguist-generated=true
*.pb.*.go linguist-generated=true
proto/internal/buf.lock linguist-generated=true
7 changes: 5 additions & 2 deletions development/prometheus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ global:
scrape_configs:
- job_name: "spicedb"
static_configs:
- targets: ["spicedb:9090"]
- targets: ["spicedb-1:9090"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI so that we can verify the new metrics in Grafana

labels:
service: "spicedb"
service: "spicedb-1"
- targets: ["spicedb-2:9090"]
labels:
service: "spicedb-2"
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ require (
go.opentelemetry.io/otel/trace v1.38.0
go.uber.org/atomic v1.11.0
go.uber.org/goleak v1.3.0
go.uber.org/mock v0.6.0
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b
golang.org/x/mod v0.28.0
golang.org/x/sync v0.17.0
Expand All @@ -138,6 +139,8 @@ tool (
github.com/golangci/golangci-lint/v2/cmd/golangci-lint
// support running mage with go run mage.go
github.com/magefile/mage/mage
// mocks are generated with go:generate directives.
go.uber.org/mock/mockgen
// vulncheck always uses the current directory's go.mod.
golang.org/x/vuln/cmd/govulncheck
)
Expand Down Expand Up @@ -313,6 +316,7 @@ require (
github.com/klauspost/cpuid/v2 v2.2.5 // indirect
github.com/kulti/thelper v0.7.1 // indirect
github.com/kunwardeep/paralleltest v1.0.14 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 // indirect
github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 // indirect
github.com/lasiar/canonicalheader v1.1.2 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2595,6 +2595,8 @@ go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwE
go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
Expand Down
2 changes: 2 additions & 0 deletions internal/dispatch/dispatch.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//go:generate go run go.uber.org/mock/mockgen -source dispatch.go -destination ../mocks/mock_dispatcher.go -package mocks Dispatcher

package dispatch

import (
Expand Down
166 changes: 166 additions & 0 deletions internal/middleware/memoryprotection/memory_protection.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
package memoryprotection

import (
"context"
"runtime/debug"
"strconv"
"strings"

middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

log "github.com/authzed/spicedb/internal/logging"
)

// RequestsProcessed tracks requests that were processed by this middleware.
var RequestsProcessed = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: "spicedb",
Subsystem: "memory_middleware",
Name: "requests_processed_total",
Help: "Total requests processed by the memory protection middleware (flag --memory-protection-enabled)",
}, []string{"endpoint", "accepted"})

// Config holds configuration for the memory protection middleware
type Config struct {
// ThresholdPercent is the memory usage threshold for requests. If zero or negative, this middleware has no effect
ThresholdPercent float64
}

// DefaultConfig returns reasonable default configuration for API requests
func DefaultConfig() Config {
return Config{
ThresholdPercent: 0.90,
}
}

// DefaultDispatchConfig returns reasonable default configuration for dispatch requests
func DefaultDispatchConfig() Config {
return Config{
ThresholdPercent: 0.95,
}
}

// MemoryLimitProvider gets and sets the limit of memory usage.
// In production, use DefaultMemoryLimitProvider.
// For testing, use HardCodedMemoryLimitProvider.
type MemoryLimitProvider interface {
GetInBytes() int64
SetInBytes(int64)
}

var (
_ MemoryLimitProvider = (*DefaultMemoryLimitProvider)(nil)
_ MemoryLimitProvider = (*HardCodedMemoryLimitProvider)(nil)
)

type DefaultMemoryLimitProvider struct{}

func (p *DefaultMemoryLimitProvider) GetInBytes() int64 {
// SetMemoryLimit returns the previously set memory limit.
// A negative input does not adjust the limit, and allows for retrieval of the currently set memory limit
return debug.SetMemoryLimit(-1)
}

func (p *DefaultMemoryLimitProvider) SetInBytes(limit int64) {
debug.SetMemoryLimit(limit)
}

type HardCodedMemoryLimitProvider struct {
Hardcodedlimit int64
}

func (p *HardCodedMemoryLimitProvider) GetInBytes() int64 {
return p.Hardcodedlimit
}

func (p *HardCodedMemoryLimitProvider) SetInBytes(limit int64) {
p.Hardcodedlimit = limit

Check warning on line 81 in internal/middleware/memoryprotection/memory_protection.go

View check run for this annotation

Codecov / codecov/patch

internal/middleware/memoryprotection/memory_protection.go#L80-L81

Added lines #L80 - L81 were not covered by tests
}

type MemoryProtectionMiddleware struct {
config Config
sampler MemorySampler
}

// New creates a new memory admission middleware with the given sampler, which is assumed to have been started already.
func New(config Config, sampler MemorySampler, name string) *MemoryProtectionMiddleware {
am := MemoryProtectionMiddleware{
config: config,
sampler: sampler,
}

if am.disabled() {
log.Warn().Str("name", name).Msg("memory protection middleware disabled")
return &am
}

log.Info().
Str("name", name).
Float64("threshold_percent", config.ThresholdPercent).
Msg("memory protection middleware initialized")

return &am
}

// UnaryServerInterceptor returns a unary server interceptor that rejects incoming requests is memory usage is too high
func (am *MemoryProtectionMiddleware) UnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
if err := am.checkAdmission(info.FullMethod); err != nil {
return nil, err
}

return handler(ctx, req)
}
}

// StreamServerInterceptor returns a stream server interceptor that rejects incoming requests is memory usage is too high
func (am *MemoryProtectionMiddleware) StreamServerInterceptor() grpc.StreamServerInterceptor {
return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
if err := am.checkAdmission(info.FullMethod); err != nil {
return err
}

wrapped := middleware.WrapServerStream(stream)
return handler(srv, wrapped)
}
}

// checkAdmission returns an error if the request should be denied because memory usage is too high.
func (am *MemoryProtectionMiddleware) checkAdmission(method string) error {
if am.disabled() {
return nil
}

accept := true
defer func() {
am.recordMetric(method, accept)
}()

if am.sampler.GetMemoryUsagePercent() >= am.config.ThresholdPercent {
accept = false
return status.Error(codes.ResourceExhausted, "server rejected the request because memory usage is above configured threshold")
}

return nil
}

func (am *MemoryProtectionMiddleware) disabled() bool {
return am.config.ThresholdPercent <= 0
}

// recordMetric updates the RequestsProcessed metric and returns the endpoint type for the input method.
func (am *MemoryProtectionMiddleware) recordMetric(fullMethod string, accepted bool) string {
endpointType := "api"
if strings.HasPrefix(fullMethod, "/dispatch.v1.DispatchService") {
endpointType = "dispatch"
}

acceptedStr := strconv.FormatBool(accepted)

RequestsProcessed.WithLabelValues(endpointType, acceptedStr).Inc()
return endpointType
}
Loading
Loading