diff --git a/api/apiv1/design/database.go b/api/apiv1/design/database.go index fb293f5f..f5d6a0bb 100644 --- a/api/apiv1/design/database.go +++ b/api/apiv1/design/database.go @@ -9,7 +9,7 @@ const ( cpuPattern = `^[0-9]+(\.[0-9]{1,3}|m)?$` postgresVersionPattern = `^\d{2}\.\d{1,2}$` spockVersionPattern = `^\d{1}$` - serviceVersionPattern = `^(\d+\.\d+\.\d+|latest)$` + serviceVersionPattern = `^(\d+\.\d+(\.\d+)?|latest)$` ) var HostIDs = g.ArrayOf(Identifier, func() { @@ -166,10 +166,10 @@ var ServiceSpec = g.Type("ServiceSpec", func() { g.Meta("struct:tag:json", "service_type") }) g.Attribute("version", g.String, func() { - g.Description("The version of the service in semver format (e.g., '1.0.0') or the literal 'latest'.") + g.Description("The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'.") g.Pattern(serviceVersionPattern) g.Example("1.0.0") - g.Example("1.2.3") + g.Example("14.5") g.Example("latest") g.Meta("struct:tag:json", "version") }) diff --git a/api/apiv1/gen/control_plane/service.go b/api/apiv1/gen/control_plane/service.go index 6da4a0a7..bdab95dd 100644 --- a/api/apiv1/gen/control_plane/service.go +++ b/api/apiv1/gen/control_plane/service.go @@ -969,8 +969,7 @@ type ServiceSpec struct { ServiceID Identifier `json:"service_id"` // The type of service to run. ServiceType string `json:"service_type"` - // The version of the service in semver format (e.g., '1.0.0') or the literal - // 'latest'. + // The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'. Version string `json:"version"` // The IDs of the hosts that should run this service. One service instance will // be created per host. diff --git a/api/apiv1/gen/http/control_plane/client/types.go b/api/apiv1/gen/http/control_plane/client/types.go index ffa31778..ccfe9b5b 100644 --- a/api/apiv1/gen/http/control_plane/client/types.go +++ b/api/apiv1/gen/http/control_plane/client/types.go @@ -2082,8 +2082,7 @@ type ServiceSpecRequestBody struct { ServiceID string `json:"service_id"` // The type of service to run. ServiceType string `json:"service_type"` - // The version of the service in semver format (e.g., '1.0.0') or the literal - // 'latest'. + // The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'. Version string `json:"version"` // The IDs of the hosts that should run this service. One service instance will // be created per host. @@ -2483,8 +2482,7 @@ type ServiceSpecResponseBody struct { ServiceID *string `json:"service_id"` // The type of service to run. ServiceType *string `json:"service_type"` - // The version of the service in semver format (e.g., '1.0.0') or the literal - // 'latest'. + // The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'. Version *string `json:"version"` // The IDs of the hosts that should run this service. One service instance will // be created per host. @@ -2812,8 +2810,7 @@ type ServiceSpecRequestBodyRequestBody struct { ServiceID string `json:"service_id"` // The type of service to run. ServiceType string `json:"service_type"` - // The version of the service in semver format (e.g., '1.0.0') or the literal - // 'latest'. + // The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'. Version string `json:"version"` // The IDs of the hosts that should run this service. One service instance will // be created per host. @@ -6256,7 +6253,7 @@ func ValidateServiceSpecRequestBody(body *ServiceSpecRequestBody) (err error) { if !(body.ServiceType == "mcp" || body.ServiceType == "postgrest" || body.ServiceType == "rag") { err = goa.MergeErrors(err, goa.InvalidEnumValueError("body.service_type", body.ServiceType, []any{"mcp", "postgrest", "rag"})) } - err = goa.MergeErrors(err, goa.ValidatePattern("body.version", body.Version, "^(\\d+\\.\\d+\\.\\d+|latest)$")) + err = goa.MergeErrors(err, goa.ValidatePattern("body.version", body.Version, "^(\\d+\\.\\d+(\\.\\d+)?|latest)$")) if len(body.HostIds) < 1 { err = goa.MergeErrors(err, goa.InvalidLengthError("body.host_ids", body.HostIds, len(body.HostIds), 1, true)) } @@ -7020,7 +7017,7 @@ func ValidateServiceSpecRequestBodyRequestBody(body *ServiceSpecRequestBodyReque if !(body.ServiceType == "mcp" || body.ServiceType == "postgrest" || body.ServiceType == "rag") { err = goa.MergeErrors(err, goa.InvalidEnumValueError("body.service_type", body.ServiceType, []any{"mcp", "postgrest", "rag"})) } - err = goa.MergeErrors(err, goa.ValidatePattern("body.version", body.Version, "^(\\d+\\.\\d+\\.\\d+|latest)$")) + err = goa.MergeErrors(err, goa.ValidatePattern("body.version", body.Version, "^(\\d+\\.\\d+(\\.\\d+)?|latest)$")) if len(body.HostIds) < 1 { err = goa.MergeErrors(err, goa.InvalidLengthError("body.host_ids", body.HostIds, len(body.HostIds), 1, true)) } diff --git a/api/apiv1/gen/http/control_plane/server/types.go b/api/apiv1/gen/http/control_plane/server/types.go index 1f3fd0b6..e8f68d2e 100644 --- a/api/apiv1/gen/http/control_plane/server/types.go +++ b/api/apiv1/gen/http/control_plane/server/types.go @@ -2166,8 +2166,7 @@ type ServiceSpecResponseBody struct { ServiceID string `json:"service_id"` // The type of service to run. ServiceType string `json:"service_type"` - // The version of the service in semver format (e.g., '1.0.0') or the literal - // 'latest'. + // The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'. Version string `json:"version"` // The IDs of the hosts that should run this service. One service instance will // be created per host. @@ -2494,8 +2493,7 @@ type ServiceSpecRequestBody struct { ServiceID *string `json:"service_id"` // The type of service to run. ServiceType *string `json:"service_type"` - // The version of the service in semver format (e.g., '1.0.0') or the literal - // 'latest'. + // The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'. Version *string `json:"version"` // The IDs of the hosts that should run this service. One service instance will // be created per host. @@ -2822,8 +2820,7 @@ type ServiceSpecRequestBodyRequestBody struct { ServiceID *string `json:"service_id"` // The type of service to run. ServiceType *string `json:"service_type"` - // The version of the service in semver format (e.g., '1.0.0') or the literal - // 'latest'. + // The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'. Version *string `json:"version"` // The IDs of the hosts that should run this service. One service instance will // be created per host. @@ -5766,7 +5763,7 @@ func ValidateServiceSpecRequestBody(body *ServiceSpecRequestBody) (err error) { } } if body.Version != nil { - err = goa.MergeErrors(err, goa.ValidatePattern("body.version", *body.Version, "^(\\d+\\.\\d+\\.\\d+|latest)$")) + err = goa.MergeErrors(err, goa.ValidatePattern("body.version", *body.Version, "^(\\d+\\.\\d+(\\.\\d+)?|latest)$")) } if len(body.HostIds) < 1 { err = goa.MergeErrors(err, goa.InvalidLengthError("body.host_ids", body.HostIds, len(body.HostIds), 1, true)) @@ -6508,7 +6505,7 @@ func ValidateServiceSpecRequestBodyRequestBody(body *ServiceSpecRequestBodyReque } } if body.Version != nil { - err = goa.MergeErrors(err, goa.ValidatePattern("body.version", *body.Version, "^(\\d+\\.\\d+\\.\\d+|latest)$")) + err = goa.MergeErrors(err, goa.ValidatePattern("body.version", *body.Version, "^(\\d+\\.\\d+(\\.\\d+)?|latest)$")) } if len(body.HostIds) < 1 { err = goa.MergeErrors(err, goa.InvalidLengthError("body.host_ids", body.HostIds, len(body.HostIds), 1, true)) diff --git a/api/apiv1/gen/http/openapi.json b/api/apiv1/gen/http/openapi.json index 878a1beb..ce65e7f5 100644 --- a/api/apiv1/gen/http/openapi.json +++ b/api/apiv1/gen/http/openapi.json @@ -8808,9 +8808,9 @@ }, "version": { "type": "string", - "description": "The version of the service in semver format (e.g., '1.0.0') or the literal 'latest'.", + "description": "The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'.", "example": "latest", - "pattern": "^(\\d+\\.\\d+\\.\\d+|latest)$" + "pattern": "^(\\d+\\.\\d+(\\.\\d+)?|latest)$" } }, "example": { diff --git a/api/apiv1/gen/http/openapi.yaml b/api/apiv1/gen/http/openapi.yaml index f3ae0fe2..68bbaaab 100644 --- a/api/apiv1/gen/http/openapi.yaml +++ b/api/apiv1/gen/http/openapi.yaml @@ -6327,9 +6327,9 @@ definitions: - rag version: type: string - description: The version of the service in semver format (e.g., '1.0.0') or the literal 'latest'. + description: The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'. example: latest - pattern: ^(\d+\.\d+\.\d+|latest)$ + pattern: ^(\d+\.\d+(\.\d+)?|latest)$ example: config: llm_model: gpt-4 diff --git a/api/apiv1/gen/http/openapi3.json b/api/apiv1/gen/http/openapi3.json index d118b29d..06ec3707 100644 --- a/api/apiv1/gen/http/openapi3.json +++ b/api/apiv1/gen/http/openapi3.json @@ -28077,9 +28077,9 @@ }, "version": { "type": "string", - "description": "The version of the service in semver format (e.g., '1.0.0') or the literal 'latest'.", + "description": "The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'.", "example": "latest", - "pattern": "^(\\d+\\.\\d+\\.\\d+|latest)$" + "pattern": "^(\\d+\\.\\d+(\\.\\d+)?|latest)$" } }, "example": { @@ -28240,9 +28240,9 @@ }, "version": { "type": "string", - "description": "The version of the service in semver format (e.g., '1.0.0') or the literal 'latest'.", + "description": "The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'.", "example": "latest", - "pattern": "^(\\d+\\.\\d+\\.\\d+|latest)$" + "pattern": "^(\\d+\\.\\d+(\\.\\d+)?|latest)$" } }, "example": { @@ -28406,9 +28406,9 @@ }, "version": { "type": "string", - "description": "The version of the service in semver format (e.g., '1.0.0') or the literal 'latest'.", + "description": "The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'.", "example": "latest", - "pattern": "^(\\d+\\.\\d+\\.\\d+|latest)$" + "pattern": "^(\\d+\\.\\d+(\\.\\d+)?|latest)$" } }, "example": { @@ -28570,9 +28570,9 @@ }, "version": { "type": "string", - "description": "The version of the service in semver format (e.g., '1.0.0') or the literal 'latest'.", + "description": "The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'.", "example": "latest", - "pattern": "^(\\d+\\.\\d+\\.\\d+|latest)$" + "pattern": "^(\\d+\\.\\d+(\\.\\d+)?|latest)$" } }, "example": { @@ -28736,9 +28736,9 @@ }, "version": { "type": "string", - "description": "The version of the service in semver format (e.g., '1.0.0') or the literal 'latest'.", + "description": "The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'.", "example": "latest", - "pattern": "^(\\d+\\.\\d+\\.\\d+|latest)$" + "pattern": "^(\\d+\\.\\d+(\\.\\d+)?|latest)$" } }, "example": { @@ -28901,9 +28901,9 @@ }, "version": { "type": "string", - "description": "The version of the service in semver format (e.g., '1.0.0') or the literal 'latest'.", + "description": "The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'.", "example": "latest", - "pattern": "^(\\d+\\.\\d+\\.\\d+|latest)$" + "pattern": "^(\\d+\\.\\d+(\\.\\d+)?|latest)$" } }, "example": { @@ -29066,9 +29066,9 @@ }, "version": { "type": "string", - "description": "The version of the service in semver format (e.g., '1.0.0') or the literal 'latest'.", + "description": "The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'.", "example": "latest", - "pattern": "^(\\d+\\.\\d+\\.\\d+|latest)$" + "pattern": "^(\\d+\\.\\d+(\\.\\d+)?|latest)$" } }, "example": { diff --git a/api/apiv1/gen/http/openapi3.yaml b/api/apiv1/gen/http/openapi3.yaml index b7fb70e7..0bd85a9a 100644 --- a/api/apiv1/gen/http/openapi3.yaml +++ b/api/apiv1/gen/http/openapi3.yaml @@ -19877,9 +19877,9 @@ components: - rag version: type: string - description: The version of the service in semver format (e.g., '1.0.0') or the literal 'latest'. + description: The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'. example: latest - pattern: ^(\d+\.\d+\.\d+|latest)$ + pattern: ^(\d+\.\d+(\.\d+)?|latest)$ example: config: llm_model: gpt-4 @@ -19995,9 +19995,9 @@ components: - rag version: type: string - description: The version of the service in semver format (e.g., '1.0.0') or the literal 'latest'. + description: The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'. example: latest - pattern: ^(\d+\.\d+\.\d+|latest)$ + pattern: ^(\d+\.\d+(\.\d+)?|latest)$ example: config: llm_model: gpt-4 @@ -20116,9 +20116,9 @@ components: - rag version: type: string - description: The version of the service in semver format (e.g., '1.0.0') or the literal 'latest'. + description: The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'. example: latest - pattern: ^(\d+\.\d+\.\d+|latest)$ + pattern: ^(\d+\.\d+(\.\d+)?|latest)$ example: config: llm_model: gpt-4 @@ -20235,9 +20235,9 @@ components: - rag version: type: string - description: The version of the service in semver format (e.g., '1.0.0') or the literal 'latest'. + description: The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'. example: latest - pattern: ^(\d+\.\d+\.\d+|latest)$ + pattern: ^(\d+\.\d+(\.\d+)?|latest)$ example: config: llm_model: gpt-4 @@ -20356,9 +20356,9 @@ components: - rag version: type: string - description: The version of the service in semver format (e.g., '1.0.0') or the literal 'latest'. + description: The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'. example: latest - pattern: ^(\d+\.\d+\.\d+|latest)$ + pattern: ^(\d+\.\d+(\.\d+)?|latest)$ example: config: llm_model: gpt-4 @@ -20476,9 +20476,9 @@ components: - rag version: type: string - description: The version of the service in semver format (e.g., '1.0.0') or the literal 'latest'. + description: The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'. example: latest - pattern: ^(\d+\.\d+\.\d+|latest)$ + pattern: ^(\d+\.\d+(\.\d+)?|latest)$ example: config: llm_model: gpt-4 @@ -20596,9 +20596,9 @@ components: - rag version: type: string - description: The version of the service in semver format (e.g., '1.0.0') or the literal 'latest'. + description: The version of the service (e.g., '1.0.0', '14.5') or the literal 'latest'. example: latest - pattern: ^(\d+\.\d+\.\d+|latest)$ + pattern: ^(\d+\.\d+(\.\d+)?|latest)$ example: config: llm_model: gpt-4 diff --git a/server/internal/api/apiv1/validate.go b/server/internal/api/apiv1/validate.go index 7f212602..4bb532ab 100644 --- a/server/internal/api/apiv1/validate.go +++ b/server/internal/api/apiv1/validate.go @@ -613,7 +613,7 @@ func validateS3RepoProperties(props repoProperties, path []string) []error { } var pgBackRestOptionPattern = regexp.MustCompile(`^[a-z0-9-]+$`) -var semverPattern = regexp.MustCompile(`^\d+\.\d+\.\d+$`) +var semverPattern = regexp.MustCompile(`^\d+\.\d+(\.\d+)?$`) // reservedLabelPrefix is the label key prefix reserved for system use. const reservedLabelPrefix = "pgedge." diff --git a/server/internal/api/apiv1/validate_test.go b/server/internal/api/apiv1/validate_test.go index c75023c8..22f77c51 100644 --- a/server/internal/api/apiv1/validate_test.go +++ b/server/internal/api/apiv1/validate_test.go @@ -813,6 +813,16 @@ func TestValidateServiceSpec(t *testing.T) { }, }, }, + { + name: "valid PostgREST service with two-part version", + svc: &api.ServiceSpec{ + ServiceID: "postgrest", + ServiceType: "postgrest", + Version: "14.5", + HostIds: []api.Identifier{"host-1"}, + Config: map[string]any{}, + }, + }, { name: "valid MCP service with 'latest' version", svc: &api.ServiceSpec{ diff --git a/server/internal/orchestrator/swarm/postgrest_config.go b/server/internal/orchestrator/swarm/postgrest_config.go new file mode 100644 index 00000000..d92a917a --- /dev/null +++ b/server/internal/orchestrator/swarm/postgrest_config.go @@ -0,0 +1,47 @@ +package swarm + +import ( + "bytes" + "fmt" + + "github.com/pgEdge/control-plane/server/internal/database" +) + +// PostgRESTConfigParams holds all inputs needed to generate a postgrest.conf file. +type PostgRESTConfigParams struct { + Config *database.PostgRESTServiceConfig +} + +// GeneratePostgRESTConfig generates the postgrest.conf file content. +// Credentials are not written here; they are injected as libpq env vars at the container level. +func GeneratePostgRESTConfig(params *PostgRESTConfigParams) ([]byte, error) { + if params == nil { + return nil, fmt.Errorf("GeneratePostgRESTConfig: params must not be nil") + } + if params.Config == nil { + return nil, fmt.Errorf("GeneratePostgRESTConfig: params.Config must not be nil") + } + cfg := params.Config + + var buf bytes.Buffer + + fmt.Fprintf(&buf, "db-schemas = %q\n", cfg.DBSchemas) + fmt.Fprintf(&buf, "db-anon-role = %q\n", cfg.DBAnonRole) + fmt.Fprintf(&buf, "db-pool = %d\n", cfg.DBPool) + fmt.Fprintf(&buf, "db-max-rows = %d\n", cfg.MaxRows) + + if cfg.JWTSecret != nil { + fmt.Fprintf(&buf, "jwt-secret = %q\n", *cfg.JWTSecret) + } + if cfg.JWTAud != nil { + fmt.Fprintf(&buf, "jwt-aud = %q\n", *cfg.JWTAud) + } + if cfg.JWTRoleClaimKey != nil { + fmt.Fprintf(&buf, "jwt-role-claim-key = %q\n", *cfg.JWTRoleClaimKey) + } + if cfg.ServerCORSAllowedOrigins != nil { + fmt.Fprintf(&buf, "server-cors-allowed-origins = %q\n", *cfg.ServerCORSAllowedOrigins) + } + + return buf.Bytes(), nil +} diff --git a/server/internal/orchestrator/swarm/postgrest_config_resource.go b/server/internal/orchestrator/swarm/postgrest_config_resource.go new file mode 100644 index 00000000..cebb8e0b --- /dev/null +++ b/server/internal/orchestrator/swarm/postgrest_config_resource.go @@ -0,0 +1,132 @@ +package swarm + +import ( + "context" + "fmt" + "path/filepath" + + "github.com/samber/do" + "github.com/spf13/afero" + + "github.com/pgEdge/control-plane/server/internal/database" + "github.com/pgEdge/control-plane/server/internal/filesystem" + "github.com/pgEdge/control-plane/server/internal/resource" +) + +var _ resource.Resource = (*PostgRESTConfigResource)(nil) + +const ResourceTypePostgRESTConfig resource.Type = "swarm.postgrest_config" + +func PostgRESTConfigResourceIdentifier(serviceInstanceID string) resource.Identifier { + return resource.Identifier{ + ID: serviceInstanceID, + Type: ResourceTypePostgRESTConfig, + } +} + +// PostgRESTConfigResource manages the postgrest.conf file on the host filesystem. +// The file is bind-mounted read-only into the container; credentials are not included. +type PostgRESTConfigResource struct { + ServiceInstanceID string `json:"service_instance_id"` + ServiceID string `json:"service_id"` + HostID string `json:"host_id"` + DirResourceID string `json:"dir_resource_id"` + Config *database.PostgRESTServiceConfig `json:"config"` +} + +func (r *PostgRESTConfigResource) ResourceVersion() string { + return "1" +} + +func (r *PostgRESTConfigResource) DiffIgnore() []string { + return nil +} + +func (r *PostgRESTConfigResource) Identifier() resource.Identifier { + return PostgRESTConfigResourceIdentifier(r.ServiceInstanceID) +} + +func (r *PostgRESTConfigResource) Executor() resource.Executor { + return resource.HostExecutor(r.HostID) +} + +func (r *PostgRESTConfigResource) Dependencies() []resource.Identifier { + return []resource.Identifier{ + filesystem.DirResourceIdentifier(r.DirResourceID), + } +} + +func (r *PostgRESTConfigResource) TypeDependencies() []resource.Type { + return nil +} + +func (r *PostgRESTConfigResource) Refresh(ctx context.Context, rc *resource.Context) error { + fs, err := do.Invoke[afero.Fs](rc.Injector) + if err != nil { + return err + } + + dirPath, err := filesystem.DirResourceFullPath(rc, r.DirResourceID) + if err != nil { + return fmt.Errorf("failed to get service data dir path: %w", err) + } + + _, err = readResourceFile(fs, filepath.Join(dirPath, "postgrest.conf")) + if err != nil { + return fmt.Errorf("failed to read PostgREST config: %w", err) + } + + return nil +} + +func (r *PostgRESTConfigResource) Create(ctx context.Context, rc *resource.Context) error { + fs, err := do.Invoke[afero.Fs](rc.Injector) + if err != nil { + return err + } + + dirPath, err := filesystem.DirResourceFullPath(rc, r.DirResourceID) + if err != nil { + return fmt.Errorf("failed to get service data dir path: %w", err) + } + + return r.writeConfigFile(fs, dirPath) +} + +func (r *PostgRESTConfigResource) Update(ctx context.Context, rc *resource.Context) error { + fs, err := do.Invoke[afero.Fs](rc.Injector) + if err != nil { + return err + } + + dirPath, err := filesystem.DirResourceFullPath(rc, r.DirResourceID) + if err != nil { + return fmt.Errorf("failed to get service data dir path: %w", err) + } + + return r.writeConfigFile(fs, dirPath) +} + +func (r *PostgRESTConfigResource) Delete(ctx context.Context, rc *resource.Context) error { + // Cleanup is handled by the parent directory resource deletion. + return nil +} + +func (r *PostgRESTConfigResource) writeConfigFile(fs afero.Fs, dirPath string) error { + content, err := GeneratePostgRESTConfig(&PostgRESTConfigParams{ + Config: r.Config, + }) + if err != nil { + return fmt.Errorf("failed to generate PostgREST config: %w", err) + } + + configPath := filepath.Join(dirPath, "postgrest.conf") + if err := afero.WriteFile(fs, configPath, content, 0o600); err != nil { + return fmt.Errorf("failed to write %s: %w", configPath, err) + } + if err := fs.Chown(configPath, postgrestContainerUID, postgrestContainerUID); err != nil { + return fmt.Errorf("failed to change ownership for %s: %w", configPath, err) + } + + return nil +} diff --git a/server/internal/orchestrator/swarm/postgrest_config_test.go b/server/internal/orchestrator/swarm/postgrest_config_test.go new file mode 100644 index 00000000..9582d0ba --- /dev/null +++ b/server/internal/orchestrator/swarm/postgrest_config_test.go @@ -0,0 +1,186 @@ +package swarm + +import ( + "strings" + "testing" + + "github.com/pgEdge/control-plane/server/internal/database" +) + +// parseConf parses the key=value lines from a postgrest.conf into a map. +// String values are returned unquoted; numeric values are returned as-is. +func parseConf(t *testing.T, data []byte) map[string]string { + t.Helper() + m := make(map[string]string) + for _, line := range strings.Split(string(data), "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + parts := strings.SplitN(line, " = ", 2) + if len(parts) != 2 { + t.Fatalf("unexpected line in postgrest.conf: %q", line) + } + key := strings.TrimSpace(parts[0]) + val := strings.TrimSpace(parts[1]) + // Strip surrounding quotes from string values. + if strings.HasPrefix(val, `"`) && strings.HasSuffix(val, `"`) { + val = val[1 : len(val)-1] + } + m[key] = val + } + return m +} + +func TestGeneratePostgRESTConfig_Defaults(t *testing.T) { + params := &PostgRESTConfigParams{ + Config: &database.PostgRESTServiceConfig{ + DBSchemas: "public", + DBAnonRole: "pgedge_application_read_only", + DBPool: 10, + MaxRows: 1000, + }, + } + + data, err := GeneratePostgRESTConfig(params) + if err != nil { + t.Fatalf("GeneratePostgRESTConfig() error = %v", err) + } + + m := parseConf(t, data) + + if m["db-schemas"] != "public" { + t.Errorf("db-schemas = %q, want %q", m["db-schemas"], "public") + } + if m["db-anon-role"] != "pgedge_application_read_only" { + t.Errorf("db-anon-role = %q, want %q", m["db-anon-role"], "pgedge_application_read_only") + } + if m["db-pool"] != "10" { + t.Errorf("db-pool = %q, want %q", m["db-pool"], "10") + } + if m["db-max-rows"] != "1000" { + t.Errorf("db-max-rows = %q, want %q", m["db-max-rows"], "1000") + } +} + +func TestGeneratePostgRESTConfig_CustomCoreFields(t *testing.T) { + params := &PostgRESTConfigParams{ + Config: &database.PostgRESTServiceConfig{ + DBSchemas: "api,private", + DBAnonRole: "web_anon", + DBPool: 5, + MaxRows: 500, + }, + } + + data, err := GeneratePostgRESTConfig(params) + if err != nil { + t.Fatalf("GeneratePostgRESTConfig() error = %v", err) + } + + m := parseConf(t, data) + + if m["db-schemas"] != "api,private" { + t.Errorf("db-schemas = %q, want %q", m["db-schemas"], "api,private") + } + if m["db-anon-role"] != "web_anon" { + t.Errorf("db-anon-role = %q, want %q", m["db-anon-role"], "web_anon") + } + if m["db-pool"] != "5" { + t.Errorf("db-pool = %q, want %q", m["db-pool"], "5") + } + if m["db-max-rows"] != "500" { + t.Errorf("db-max-rows = %q, want %q", m["db-max-rows"], "500") + } +} + +func TestGeneratePostgRESTConfig_JWTFieldsAbsent(t *testing.T) { + // No JWT fields set — none should appear in the config file. + params := &PostgRESTConfigParams{ + Config: &database.PostgRESTServiceConfig{ + DBSchemas: "public", + DBAnonRole: "web_anon", + DBPool: 10, + MaxRows: 1000, + }, + } + + data, err := GeneratePostgRESTConfig(params) + if err != nil { + t.Fatalf("GeneratePostgRESTConfig() error = %v", err) + } + + m := parseConf(t, data) + + for _, key := range []string{"jwt-secret", "jwt-aud", "jwt-role-claim-key", "server-cors-allowed-origins"} { + if _, ok := m[key]; ok { + t.Errorf("%s should be absent when not configured, but it was present", key) + } + } +} + +func TestGeneratePostgRESTConfig_AllJWTFields(t *testing.T) { + secret := "a-very-long-jwt-secret-that-is-at-least-32-chars" + aud := "my-api-audience" + roleClaimKey := ".role" + corsOrigins := "https://example.com" + + params := &PostgRESTConfigParams{ + Config: &database.PostgRESTServiceConfig{ + DBSchemas: "public", + DBAnonRole: "web_anon", + DBPool: 10, + MaxRows: 1000, + JWTSecret: &secret, + JWTAud: &aud, + JWTRoleClaimKey: &roleClaimKey, + ServerCORSAllowedOrigins: &corsOrigins, + }, + } + + data, err := GeneratePostgRESTConfig(params) + if err != nil { + t.Fatalf("GeneratePostgRESTConfig() error = %v", err) + } + + m := parseConf(t, data) + + if m["jwt-secret"] != secret { + t.Errorf("jwt-secret = %q, want %q", m["jwt-secret"], secret) + } + if m["jwt-aud"] != aud { + t.Errorf("jwt-aud = %q, want %q", m["jwt-aud"], aud) + } + if m["jwt-role-claim-key"] != roleClaimKey { + t.Errorf("jwt-role-claim-key = %q, want %q", m["jwt-role-claim-key"], roleClaimKey) + } + if m["server-cors-allowed-origins"] != corsOrigins { + t.Errorf("server-cors-allowed-origins = %q, want %q", m["server-cors-allowed-origins"], corsOrigins) + } +} + +func TestGeneratePostgRESTConfig_CredentialsNotInFile(t *testing.T) { + // Verify that no credential-like keys ever appear in the config file. + secret := "a-very-long-jwt-secret-that-is-at-least-32-chars" + params := &PostgRESTConfigParams{ + Config: &database.PostgRESTServiceConfig{ + DBSchemas: "public", + DBAnonRole: "web_anon", + DBPool: 10, + MaxRows: 1000, + JWTSecret: &secret, + }, + } + + data, err := GeneratePostgRESTConfig(params) + if err != nil { + t.Fatalf("GeneratePostgRESTConfig() error = %v", err) + } + + // None of the libpq / db-uri credential keys should appear. + for _, forbidden := range []string{"db-uri", "PGUSER", "PGPASSWORD", "PGHOST", "PGPORT", "PGDATABASE"} { + if strings.Contains(string(data), forbidden) { + t.Errorf("config file must not contain %q (credentials are env vars)", forbidden) + } + } +} diff --git a/server/internal/orchestrator/swarm/postgrest_preflight_resource.go b/server/internal/orchestrator/swarm/postgrest_preflight_resource.go new file mode 100644 index 00000000..55b5a756 --- /dev/null +++ b/server/internal/orchestrator/swarm/postgrest_preflight_resource.go @@ -0,0 +1,134 @@ +package swarm + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/pgEdge/control-plane/server/internal/database" + "github.com/pgEdge/control-plane/server/internal/resource" +) + +var _ resource.Resource = (*PostgRESTPreflightResource)(nil) + +const ResourceTypePostgRESTPreflightResource resource.Type = "swarm.postgrest_preflight" + +func PostgRESTPreflightResourceIdentifier(serviceID string) resource.Identifier { + return resource.Identifier{ + ID: serviceID, + Type: ResourceTypePostgRESTPreflightResource, + } +} + +// PostgRESTPreflightResource validates that the configured schemas and anon role +// exist in the database before PostgREST is provisioned. It uses PrimaryExecutor +// so the check runs on a host with guaranteed database connectivity. +type PostgRESTPreflightResource struct { + ServiceID string `json:"service_id"` + DatabaseID string `json:"database_id"` + DatabaseName string `json:"database_name"` + NodeName string `json:"node_name"` + DBSchemas string `json:"db_schemas"` + DBAnonRole string `json:"db_anon_role"` +} + +func (r *PostgRESTPreflightResource) ResourceVersion() string { return "1" } +func (r *PostgRESTPreflightResource) DiffIgnore() []string { return nil } + +func (r *PostgRESTPreflightResource) Identifier() resource.Identifier { + return PostgRESTPreflightResourceIdentifier(r.ServiceID) +} + +func (r *PostgRESTPreflightResource) Executor() resource.Executor { + return resource.PrimaryExecutor(r.NodeName) +} + +func (r *PostgRESTPreflightResource) Dependencies() []resource.Identifier { + return nil +} + +func (r *PostgRESTPreflightResource) TypeDependencies() []resource.Type { + return nil +} + +// Refresh validates prerequisites and returns ErrNotFound only when validation +// fails, triggering a Create that surfaces the error. When prerequisites are +// satisfied the resource is considered up-to-date (no permadrift). +func (r *PostgRESTPreflightResource) Refresh(ctx context.Context, rc *resource.Context) error { + if err := r.validate(ctx, rc); err != nil { + return fmt.Errorf("%w: %s", resource.ErrNotFound, err.Error()) + } + return nil +} + +func (r *PostgRESTPreflightResource) Create(ctx context.Context, rc *resource.Context) error { + return r.validate(ctx, rc) +} + +func (r *PostgRESTPreflightResource) Update(ctx context.Context, rc *resource.Context) error { + return r.validate(ctx, rc) +} + +func (r *PostgRESTPreflightResource) Delete(ctx context.Context, rc *resource.Context) error { + return nil +} + +func (r *PostgRESTPreflightResource) validate(ctx context.Context, rc *resource.Context) error { + primary, err := database.GetPrimaryInstance(ctx, rc, r.NodeName) + if err != nil { + return fmt.Errorf("preflight: failed to get primary instance: %w", err) + } + conn, err := primary.Connection(ctx, rc, r.DatabaseName) + if err != nil { + return fmt.Errorf("preflight: failed to connect to database %s on node %s: %w", r.DatabaseName, r.NodeName, err) + } + defer conn.Close(ctx) + + var errs []error + + for _, schema := range splitSchemas(r.DBSchemas) { + var exists bool + if err := conn.QueryRow(ctx, + "SELECT EXISTS (SELECT 1 FROM information_schema.schemata WHERE schema_name = $1)", + schema, + ).Scan(&exists); err != nil { + errs = append(errs, fmt.Errorf("failed to check schema %q: %w", schema, err)) + continue + } + if !exists { + errs = append(errs, fmt.Errorf( + "schema %q does not exist in database %q; create it before deploying PostgREST", + schema, r.DatabaseName, + )) + } + } + + if r.DBAnonRole != "" { + var exists bool + if err := conn.QueryRow(ctx, + "SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_roles WHERE rolname = $1)", + r.DBAnonRole, + ).Scan(&exists); err != nil { + errs = append(errs, fmt.Errorf("failed to check role %q: %w", r.DBAnonRole, err)) + } else if !exists { + errs = append(errs, fmt.Errorf( + "role %q does not exist on the Postgres cluster; create it before deploying PostgREST", + r.DBAnonRole, + )) + } + } + + return errors.Join(errs...) +} + +func splitSchemas(s string) []string { + parts := strings.Split(s, ",") + schemas := make([]string, 0, len(parts)) + for _, p := range parts { + if p = strings.TrimSpace(p); p != "" { + schemas = append(schemas, p) + } + } + return schemas +} diff --git a/server/internal/orchestrator/swarm/resources.go b/server/internal/orchestrator/swarm/resources.go index 4878137f..e688039e 100644 --- a/server/internal/orchestrator/swarm/resources.go +++ b/server/internal/orchestrator/swarm/resources.go @@ -21,4 +21,6 @@ func RegisterResourceTypes(registry *resource.Registry) { resource.RegisterResourceType[*Switchover](registry, ResourceTypeSwitchover) resource.RegisterResourceType[*ScaleService](registry, ResourceTypeScaleService) resource.RegisterResourceType[*MCPConfigResource](registry, ResourceTypeMCPConfig) + resource.RegisterResourceType[*PostgRESTPreflightResource](registry, ResourceTypePostgRESTPreflightResource) + resource.RegisterResourceType[*PostgRESTConfigResource](registry, ResourceTypePostgRESTConfig) } diff --git a/server/internal/orchestrator/swarm/service_images.go b/server/internal/orchestrator/swarm/service_images.go index 7ff60d8b..c1d28fa6 100644 --- a/server/internal/orchestrator/swarm/service_images.go +++ b/server/internal/orchestrator/swarm/service_images.go @@ -49,6 +49,14 @@ func NewServiceVersions(cfg config.Config) *ServiceVersions { // No constraints — MCP works with all PG/Spock versions. }) + // PostgREST service versions. + // Images are published to the pgEdge registry under ghcr.io/pgedge/postgrest. + // The bare ref (no registry prefix) lets serviceImageTag prepend the + // configured ImageRepositoryHost (e.g. ghcr.io/pgedge). + versions.addServiceImage("postgrest", "14.5", &ServiceImage{ + Tag: serviceImageTag(cfg, "postgrest:14.5"), + }) + // RAG service versions // TODO: Register semver versions when official releases are published. versions.addServiceImage("rag", "latest", &ServiceImage{ diff --git a/server/internal/orchestrator/swarm/service_images_test.go b/server/internal/orchestrator/swarm/service_images_test.go index 17e35041..784600c6 100644 --- a/server/internal/orchestrator/swarm/service_images_test.go +++ b/server/internal/orchestrator/swarm/service_images_test.go @@ -1,6 +1,7 @@ package swarm import ( + "slices" "testing" "github.com/pgEdge/control-plane/server/internal/config" @@ -29,6 +30,13 @@ func TestGetServiceImage(t *testing.T) { wantTag: "ghcr.io/pgedge/postgres-mcp:latest", wantErr: false, }, + { + name: "valid postgrest 14.5", + serviceType: "postgrest", + version: "14.5", + wantTag: "ghcr.io/pgedge/postgrest:14.5", + wantErr: false, + }, { name: "unsupported service type", serviceType: "unknown", @@ -81,21 +89,29 @@ func TestSupportedServiceVersions(t *testing.T) { sv := NewServiceVersions(cfg) tests := []struct { - name string - serviceType string - wantLen int - wantErr bool + name string + serviceType string + wantLatest bool // whether "latest" must be present + minPinnedCount int // minimum number of pinned (non-"latest") versions required + wantErr bool }{ { - name: "mcp service has versions", - serviceType: "mcp", - wantLen: 1, // "latest" - wantErr: false, + name: "mcp service has versions", + serviceType: "mcp", + wantLatest: true, + minPinnedCount: 0, + wantErr: false, + }, + { + name: "postgrest service has versions", + serviceType: "postgrest", + wantLatest: false, + minPinnedCount: 1, // at least one pinned release (e.g. 14.5 or newer) + wantErr: false, }, { name: "unsupported service type", serviceType: "unknown", - wantLen: 0, wantErr: true, }, } @@ -107,8 +123,19 @@ func TestSupportedServiceVersions(t *testing.T) { t.Errorf("SupportedServiceVersions() error = %v, wantErr %v", err, tt.wantErr) return } - if len(got) != tt.wantLen { - t.Errorf("SupportedServiceVersions() returned %d versions, want %d", len(got), tt.wantLen) + if !tt.wantErr { + if tt.wantLatest && !slices.Contains(got, "latest") { + t.Errorf("SupportedServiceVersions() missing required version \"latest\", got %v", got) + } + pinned := 0 + for _, v := range got { + if v != "latest" { + pinned++ + } + } + if pinned < tt.minPinnedCount { + t.Errorf("SupportedServiceVersions() has %d pinned version(s), want at least %d", pinned, tt.minPinnedCount) + } } }) } @@ -183,6 +210,22 @@ func TestGetServiceImage_ConstraintsPopulated(t *testing.T) { } }) + t.Run("postgrest has no constraints", func(t *testing.T) { + img, err := sv.GetServiceImage("postgrest", "14.5") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if img.Tag != "ghcr.io/pgedge/postgrest:14.5" { + t.Errorf("expected ghcr.io/pgedge/postgrest:14.5, got %s", img.Tag) + } + if img.PostgresConstraint != nil { + t.Error("expected nil PostgresConstraint for postgrest") + } + if img.SpockConstraint != nil { + t.Error("expected nil SpockConstraint for postgrest") + } + }) + } func mustVersion(t *testing.T, s string) *host.Version { diff --git a/server/internal/orchestrator/swarm/service_instance_spec.go b/server/internal/orchestrator/swarm/service_instance_spec.go index f613bac1..8514c8a6 100644 --- a/server/internal/orchestrator/swarm/service_instance_spec.go +++ b/server/internal/orchestrator/swarm/service_instance_spec.go @@ -104,19 +104,21 @@ func (s *ServiceInstanceSpecResource) Refresh(ctx context.Context, rc *resource. } spec, err := ServiceContainerSpec(&ServiceContainerSpecOptions{ - ServiceSpec: s.ServiceSpec, - ServiceInstanceID: s.ServiceInstanceID, - DatabaseID: s.DatabaseID, - DatabaseName: s.DatabaseName, - HostID: s.HostID, - ServiceName: s.ServiceName, - Hostname: s.Hostname, - CohortMemberID: s.CohortMemberID, - ServiceImage: s.ServiceImage, - Credentials: s.Credentials, - DatabaseNetworkID: network.NetworkID, - Port: s.Port, - DataPath: dataPath, + ServiceSpec: s.ServiceSpec, + ServiceInstanceID: s.ServiceInstanceID, + DatabaseID: s.DatabaseID, + DatabaseName: s.DatabaseName, + HostID: s.HostID, + ServiceName: s.ServiceName, + Hostname: s.Hostname, + CohortMemberID: s.CohortMemberID, + ServiceImage: s.ServiceImage, + Credentials: s.Credentials, + DatabaseNetworkID: network.NetworkID, + DatabaseHosts: s.DatabaseHosts, + TargetSessionAttrs: s.TargetSessionAttrs, + Port: s.Port, + DataPath: dataPath, }) if err != nil { return fmt.Errorf("failed to generate service container spec: %w", err) diff --git a/server/internal/orchestrator/swarm/service_spec.go b/server/internal/orchestrator/swarm/service_spec.go index f566fa0f..5aba4211 100644 --- a/server/internal/orchestrator/swarm/service_spec.go +++ b/server/internal/orchestrator/swarm/service_spec.go @@ -2,6 +2,8 @@ package swarm import ( "fmt" + "strconv" + "strings" "time" "github.com/docker/docker/api/types/container" @@ -15,19 +17,53 @@ import ( // mcpContainerUID is the UID of the MCP container user. const mcpContainerUID = 1001 +// postgrestContainerUID is the UID of the PostgREST container user. +// See: https://github.com/PostgREST/postgrest/blob/main/Dockerfile (USER 1000) +const postgrestContainerUID = 1000 + +func buildPostgRESTEnvVars(opts *ServiceContainerSpecOptions) []string { + hosts := make([]string, 0, len(opts.DatabaseHosts)) + ports := make([]string, 0, len(opts.DatabaseHosts)) + for _, h := range opts.DatabaseHosts { + hosts = append(hosts, h.Host) + ports = append(ports, strconv.Itoa(h.Port)) + } + env := []string{ + "PGRST_DB_URI=postgresql://", + "PGRST_SERVER_HOST=0.0.0.0", + "PGRST_SERVER_PORT=8080", + "PGRST_ADMIN_SERVER_PORT=3001", + fmt.Sprintf("PGHOST=%s", strings.Join(hosts, ",")), + fmt.Sprintf("PGPORT=%s", strings.Join(ports, ",")), + fmt.Sprintf("PGDATABASE=%s", opts.DatabaseName), + } + if opts.TargetSessionAttrs != "" { + env = append(env, fmt.Sprintf("PGTARGETSESSIONATTRS=%s", opts.TargetSessionAttrs)) + } + if opts.Credentials != nil { + env = append(env, + fmt.Sprintf("PGUSER=%s", opts.Credentials.Username), + fmt.Sprintf("PGPASSWORD=%s", opts.Credentials.Password), + ) + } + return env +} + // ServiceContainerSpecOptions contains all parameters needed to build a service container spec. type ServiceContainerSpecOptions struct { - ServiceSpec *database.ServiceSpec - ServiceInstanceID string - DatabaseID string - DatabaseName string - HostID string - ServiceName string - Hostname string - CohortMemberID string - ServiceImage *ServiceImage - Credentials *database.ServiceUser - DatabaseNetworkID string + ServiceSpec *database.ServiceSpec + ServiceInstanceID string + DatabaseID string + DatabaseName string + HostID string + ServiceName string + Hostname string + CohortMemberID string + ServiceImage *ServiceImage + Credentials *database.ServiceUser + DatabaseNetworkID string + DatabaseHosts []database.ServiceHostEntry // Ordered Postgres host:port entries + TargetSessionAttrs string // libpq target_session_attrs // Service port configuration Port *int // DataPath is the host-side directory path for the bind mount @@ -88,29 +124,64 @@ func ServiceContainerSpec(opts *ServiceContainerSpecOptions) (swarm.ServiceSpec, } } - // Build bind mount for config/auth files - mounts := []mount.Mount{ - docker.BuildMount(opts.DataPath, "/app/data", false), + // Build the container-spec fields that vary by service type. + var ( + command []string + args []string + env []string + user string + healthcheck *container.HealthConfig + mounts []mount.Mount + ) + + switch opts.ServiceSpec.ServiceType { + case "postgrest": + user = fmt.Sprintf("%d", postgrestContainerUID) + command = []string{"postgrest"} + args = []string{"/app/data/postgrest.conf"} + env = buildPostgRESTEnvVars(opts) + // postgrest --ready exits 0/1; no curl in the static binary image. + healthcheck = &container.HealthConfig{ + Test: []string{"CMD", "postgrest", "--ready"}, + StartPeriod: time.Second * 30, + Interval: time.Second * 10, + Timeout: time.Second * 5, + Retries: 3, + } + mounts = []mount.Mount{ + docker.BuildMount(opts.DataPath, "/app/data", true), + } + case "mcp": + user = fmt.Sprintf("%d", mcpContainerUID) + // Override the default container entrypoint to specify config path on bind mount. + command = []string{"/app/pgedge-postgres-mcp"} + args = []string{"-config", "/app/data/config.yaml"} + healthcheck = &container.HealthConfig{ + Test: []string{"CMD-SHELL", "curl -f http://localhost:8080/health || exit 1"}, + StartPeriod: time.Second * 30, + Interval: time.Second * 10, + Timeout: time.Second * 5, + Retries: 3, + } + mounts = []mount.Mount{ + docker.BuildMount(opts.DataPath, "/app/data", false), + } + default: + return swarm.ServiceSpec{}, fmt.Errorf("unsupported service type: %q", opts.ServiceSpec.ServiceType) } return swarm.ServiceSpec{ TaskTemplate: swarm.TaskSpec{ ContainerSpec: &swarm.ContainerSpec{ - Image: image, - Labels: labels, - Hostname: opts.Hostname, - User: fmt.Sprintf("%d", mcpContainerUID), - // override the default container entrypoint so we can specify path to config on bind mount - Command: []string{"/app/pgedge-postgres-mcp"}, - Args: []string{"-config", "/app/data/config.yaml"}, - Healthcheck: &container.HealthConfig{ - Test: []string{"CMD-SHELL", "curl -f http://localhost:8080/health || exit 1"}, - StartPeriod: time.Second * 30, - Interval: time.Second * 10, - Timeout: time.Second * 5, - Retries: 3, - }, - Mounts: mounts, + Image: image, + Labels: labels, + Hostname: opts.Hostname, + User: user, + Command: command, + Args: args, + Env: env, + Healthcheck: healthcheck, + Mounts: mounts, }, Networks: networks, Placement: &swarm.Placement{ diff --git a/server/internal/orchestrator/swarm/service_spec_test.go b/server/internal/orchestrator/swarm/service_spec_test.go index a9d5eac5..49aea7e7 100644 --- a/server/internal/orchestrator/swarm/service_spec_test.go +++ b/server/internal/orchestrator/swarm/service_spec_test.go @@ -2,6 +2,7 @@ package swarm import ( "fmt" + "strings" "testing" "github.com/docker/docker/api/types/swarm" @@ -307,3 +308,122 @@ func TestBuildServicePortConfig(t *testing.T) { func intPtr(i int) *int { return &i } + +// --- PostgREST container spec tests --- + +func makePostgRESTSpecOpts() *ServiceContainerSpecOptions { + return &ServiceContainerSpecOptions{ + ServiceSpec: &database.ServiceSpec{ + ServiceID: "svc-1", + ServiceType: "postgrest", + }, + ServiceInstanceID: "inst-1", + DatabaseID: "db-1", + DatabaseName: "mydb", + HostID: "host-1", + ServiceName: "svc-mydb-postgrest", + Hostname: "postgrest-host1", + CohortMemberID: "node-abc", + ServiceImage: &ServiceImage{Tag: "postgrest/postgrest:latest"}, + Credentials: &database.ServiceUser{ + Username: "svc_postgrest_host1", + Password: "supersecret", + }, + DatabaseNetworkID: "net-1", + DatabaseHosts: []database.ServiceHostEntry{{Host: "pg-host1", Port: 5432}}, + DataPath: "/var/lib/pgedge/services/inst-1", + } +} + +func TestServiceContainerSpec_PostgREST_Command(t *testing.T) { + spec, err := ServiceContainerSpec(makePostgRESTSpecOpts()) + if err != nil { + t.Fatalf("ServiceContainerSpec() error = %v", err) + } + cs := spec.TaskTemplate.ContainerSpec + if len(cs.Command) != 1 || cs.Command[0] != "postgrest" { + t.Errorf("Command = %v, want [\"postgrest\"]", cs.Command) + } + if len(cs.Args) != 1 || cs.Args[0] != "/app/data/postgrest.conf" { + t.Errorf("Args = %v, want [\"/app/data/postgrest.conf\"]", cs.Args) + } +} + +func TestServiceContainerSpec_PostgREST_HealthCheck(t *testing.T) { + spec, err := ServiceContainerSpec(makePostgRESTSpecOpts()) + if err != nil { + t.Fatalf("ServiceContainerSpec() error = %v", err) + } + hc := spec.TaskTemplate.ContainerSpec.Healthcheck + if hc == nil { + t.Fatal("Healthcheck is nil") + } + want := []string{"CMD", "postgrest", "--ready"} + if len(hc.Test) != len(want) { + t.Fatalf("Healthcheck.Test = %v, want %v", hc.Test, want) + } + for i, v := range want { + if hc.Test[i] != v { + t.Errorf("Healthcheck.Test[%d] = %q, want %q", i, hc.Test[i], v) + } + } +} + +func TestServiceContainerSpec_PostgREST_EnvVars(t *testing.T) { + spec, err := ServiceContainerSpec(makePostgRESTSpecOpts()) + if err != nil { + t.Fatalf("ServiceContainerSpec() error = %v", err) + } + envMap := make(map[string]string) + for _, e := range spec.TaskTemplate.ContainerSpec.Env { + parts := strings.SplitN(e, "=", 2) + if len(parts) == 2 { + envMap[parts[0]] = parts[1] + } + } + checks := map[string]string{ + "PGRST_DB_URI": "postgresql://", + "PGRST_SERVER_PORT": "8080", + "PGRST_ADMIN_SERVER_PORT": "3001", + "PGHOST": "pg-host1", + "PGPORT": "5432", + "PGDATABASE": "mydb", + "PGUSER": "svc_postgrest_host1", + "PGPASSWORD": "supersecret", + } + for key, want := range checks { + if got, ok := envMap[key]; !ok { + t.Errorf("env var %s is missing", key) + } else if got != want { + t.Errorf("env var %s = %q, want %q", key, got, want) + } + } +} + +func TestServiceContainerSpec_PostgREST_MountReadOnly(t *testing.T) { + spec, err := ServiceContainerSpec(makePostgRESTSpecOpts()) + if err != nil { + t.Fatalf("ServiceContainerSpec() error = %v", err) + } + mounts := spec.TaskTemplate.ContainerSpec.Mounts + if len(mounts) != 1 { + t.Fatalf("len(Mounts) = %d, want 1", len(mounts)) + } + if !mounts[0].ReadOnly { + t.Error("data mount should be read-only for PostgREST") + } + if mounts[0].Target != "/app/data" { + t.Errorf("mount target = %q, want \"/app/data\"", mounts[0].Target) + } +} + +func TestServiceContainerSpec_PostgREST_User(t *testing.T) { + spec, err := ServiceContainerSpec(makePostgRESTSpecOpts()) + if err != nil { + t.Fatalf("ServiceContainerSpec() error = %v", err) + } + want := fmt.Sprintf("%d", postgrestContainerUID) + if spec.TaskTemplate.ContainerSpec.User != want { + t.Errorf("User = %q, want %q (PostgREST runs as UID 1000 per official Dockerfile)", spec.TaskTemplate.ContainerSpec.User, want) + } +} diff --git a/server/internal/orchestrator/swarm/service_user_role.go b/server/internal/orchestrator/swarm/service_user_role.go index 7724690b..019d8d01 100644 --- a/server/internal/orchestrator/swarm/service_user_role.go +++ b/server/internal/orchestrator/swarm/service_user_role.go @@ -65,8 +65,10 @@ type ServiceUserRole struct { ServiceID string `json:"service_id"` DatabaseID string `json:"database_id"` DatabaseName string `json:"database_name"` - NodeName string `json:"node_name"` // Database node name for PrimaryExecutor routing - Mode string `json:"mode"` // ServiceUserRoleRO or ServiceUserRoleRW + NodeName string `json:"node_name"` // Database node name for PrimaryExecutor routing + Mode string `json:"mode"` // ServiceUserRoleRO or ServiceUserRoleRW + ServiceType string `json:"service_type"` // "mcp" or "postgrest" + DBAnonRole string `json:"db_anon_role"` // PostgREST only: anonymous role granted to the service user Username string `json:"username"` Password string `json:"password"` // Generated on Create, persisted in state CredentialSource *resource.Identifier `json:"credential_source,omitempty"` @@ -166,35 +168,68 @@ func (r *ServiceUserRole) createUserRole(ctx context.Context, rc *resource.Conte } defer conn.Close(ctx) - // Determine group role based on mode - var groupRole string - switch r.Mode { - case ServiceUserRoleRO: - groupRole = "pgedge_application_read_only" - case ServiceUserRoleRW: - groupRole = "pgedge_application" - default: - return fmt.Errorf("unknown service user role mode: %q", r.Mode) - } - - statements, err := postgres.CreateUserRole(postgres.UserRoleOptions{ - Name: r.Username, - Password: r.Password, - DBOwner: false, - Attributes: []string{"LOGIN"}, - Roles: []string{groupRole}, - }) - if err != nil { - return fmt.Errorf("failed to generate create user role statements: %w", err) - } - - if err := statements.Exec(ctx, conn); err != nil { - return fmt.Errorf("failed to create service user: %w", err) + if r.ServiceType == "postgrest" { + attributes, grants := r.roleAttributesAndGrants() + statements, err := postgres.CreateUserRole(postgres.UserRoleOptions{ + Name: r.Username, + Password: r.Password, + DBOwner: false, + Attributes: attributes, + }) + if err != nil { + return fmt.Errorf("failed to generate create user role statements: %w", err) + } + if err := statements.Exec(ctx, conn); err != nil { + return fmt.Errorf("failed to create service user: %w", err) + } + if err := grants.Exec(ctx, conn); err != nil { + return fmt.Errorf("failed to grant service user permissions: %w", err) + } + } else { + var groupRole string + switch r.Mode { + case ServiceUserRoleRO: + groupRole = "pgedge_application_read_only" + case ServiceUserRoleRW: + groupRole = "pgedge_application" + default: + return fmt.Errorf("unknown service user role mode: %q", r.Mode) + } + statements, err := postgres.CreateUserRole(postgres.UserRoleOptions{ + Name: r.Username, + Password: r.Password, + DBOwner: false, + Attributes: []string{"LOGIN"}, + Roles: []string{groupRole}, + }) + if err != nil { + return fmt.Errorf("failed to generate create user role statements: %w", err) + } + if err := statements.Exec(ctx, conn); err != nil { + return fmt.Errorf("failed to create service user: %w", err) + } } return nil } +// roleAttributesAndGrants returns the PostgREST-specific role attributes and +// SQL grant statements. Only called when ServiceType == "postgrest"; +// MCP uses the group-role path in createUserRole() directly. +func (r *ServiceUserRole) roleAttributesAndGrants() ([]string, postgres.Statements) { + // NOINHERIT + GRANT enables PostgREST's SET ROLE mechanism. + attributes := []string{"LOGIN", "NOINHERIT"} + anonRole := r.DBAnonRole + if anonRole == "" { + anonRole = "pgedge_application_read_only" + } + grants := postgres.Statements{ + postgres.Statement{SQL: fmt.Sprintf("GRANT CONNECT ON DATABASE %s TO %s;", sanitizeIdentifier(r.DatabaseName), sanitizeIdentifier(r.Username))}, + postgres.Statement{SQL: fmt.Sprintf("GRANT %s TO %s;", sanitizeIdentifier(anonRole), sanitizeIdentifier(r.Username))}, + } + return attributes, grants +} + func (r *ServiceUserRole) Update(ctx context.Context, rc *resource.Context) error { // Service users don't support updates (no credential rotation in Phase 1) return nil diff --git a/server/internal/orchestrator/swarm/service_user_role_test.go b/server/internal/orchestrator/swarm/service_user_role_test.go index ab23f0c1..735f8fc3 100644 --- a/server/internal/orchestrator/swarm/service_user_role_test.go +++ b/server/internal/orchestrator/swarm/service_user_role_test.go @@ -2,9 +2,11 @@ package swarm import ( "fmt" + "strings" "testing" "github.com/pgEdge/control-plane/server/internal/database" + "github.com/pgEdge/control-plane/server/internal/postgres" "github.com/pgEdge/control-plane/server/internal/resource" ) @@ -209,3 +211,99 @@ func TestServiceUserRolePerNodeIdentifierUniqueness(t *testing.T) { seen[id] = fmt.Sprintf("role[%d] node=%s mode=%s", i, r.NodeName, r.Mode) } } + +// statementsSQL extracts the raw SQL from a postgres.Statements slice. +func statementsSQL(stmts postgres.Statements) []string { + out := make([]string, 0, len(stmts)) + for i, s := range stmts { + if stmt, ok := s.(postgres.Statement); ok { + out = append(out, stmt.SQL) + continue + } + panic(fmt.Sprintf("statementsSQL: unexpected statement type %T at index %d", s, i)) + } + return out +} + +func joinSQL(stmts postgres.Statements) string { + return strings.Join(statementsSQL(stmts), "\n") +} + +func TestRoleAttributesAndGrants_PostgREST_Attributes(t *testing.T) { + r := &ServiceUserRole{ + ServiceType: "postgrest", + DatabaseName: "mydb", + Username: "svc_pgrest", + DBAnonRole: "web_anon", + } + attrs, _ := r.roleAttributesAndGrants() + + attrSet := make(map[string]bool) + for _, a := range attrs { + attrSet[a] = true + } + if !attrSet["LOGIN"] { + t.Error("PostgREST attributes must include LOGIN") + } + if !attrSet["NOINHERIT"] { + t.Error("PostgREST attributes must include NOINHERIT") + } +} + +func TestRoleAttributesAndGrants_PostgREST_GrantsAnonRole(t *testing.T) { + r := &ServiceUserRole{ + ServiceType: "postgrest", + DatabaseName: "mydb", + Username: "svc_pgrest", + DBAnonRole: "web_anon", + } + _, grants := r.roleAttributesAndGrants() + sql := joinSQL(grants) + + if !strings.Contains(sql, "GRANT CONNECT") { + t.Errorf("PostgREST grants missing GRANT CONNECT\nGot:\n%s", sql) + } + if !strings.Contains(sql, `"web_anon"`) { + t.Errorf("PostgREST grants must grant configured DBAnonRole\nGot:\n%s", sql) + } +} + +func TestRoleAttributesAndGrants_PostgREST_DefaultAnonRole(t *testing.T) { + // Empty DBAnonRole → default to pgedge_application_read_only + r := &ServiceUserRole{ + ServiceType: "postgrest", + DatabaseName: "mydb", + Username: "svc_pgrest", + DBAnonRole: "", + } + _, grants := r.roleAttributesAndGrants() + sql := joinSQL(grants) + + if !strings.Contains(sql, `"pgedge_application_read_only"`) { + t.Errorf("PostgREST must default DBAnonRole to pgedge_application_read_only\nGot:\n%s", sql) + } +} + +func TestRoleAttributesAndGrants_PostgREST_NoDirectTableGrants(t *testing.T) { + // PostgREST accesses tables via the anon role — no direct table grants. + r := &ServiceUserRole{ + ServiceType: "postgrest", + DatabaseName: "mydb", + Username: "svc_pgrest", + DBAnonRole: "web_anon", + } + _, grants := r.roleAttributesAndGrants() + sql := joinSQL(grants) + + for _, forbidden := range []string{ + "GRANT SELECT", + "GRANT USAGE ON SCHEMA", + "ALTER DEFAULT PRIVILEGES", + "pg_read_all_settings", + } { + if strings.Contains(sql, forbidden) { + t.Errorf("PostgREST grants must not include %q (accesses tables via anon role)\nGot:\n%s", forbidden, sql) + } + } +} +