diff --git a/internal/migration_acceptance_tests/view_privilege_cases_test.go b/internal/migration_acceptance_tests/view_privilege_cases_test.go new file mode 100644 index 0000000..5f9cdd9 --- /dev/null +++ b/internal/migration_acceptance_tests/view_privilege_cases_test.go @@ -0,0 +1,160 @@ +package migration_acceptance_tests + +import ( + "testing" + + "github.com/stripe/pg-schema-diff/pkg/diff" +) + +var viewPrivilegeAcceptanceTestCases = []acceptanceTestCase{ + { + name: "no-op: view with existing privilege", + roles: []string{"app_user"}, + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + CREATE VIEW foobar_view AS SELECT id FROM foobar; + GRANT SELECT ON foobar_view TO app_user; + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + CREATE VIEW foobar_view AS SELECT id FROM foobar; + GRANT SELECT ON foobar_view TO app_user; + `, + }, + expectEmptyPlan: true, + }, + { + name: "Grant SELECT on view", + roles: []string{"app_user"}, + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + CREATE VIEW foobar_view AS SELECT id FROM foobar; + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + CREATE VIEW foobar_view AS SELECT id FROM foobar; + GRANT SELECT ON foobar_view TO app_user; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + name: "Revoke SELECT on view", + roles: []string{"app_user"}, + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + CREATE VIEW foobar_view AS SELECT id FROM foobar; + GRANT SELECT ON foobar_view TO app_user; + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + CREATE VIEW foobar_view AS SELECT id FROM foobar; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + name: "Grant multiple privileges on view", + roles: []string{"app_user"}, + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + CREATE VIEW foobar_view AS SELECT id FROM foobar; + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + CREATE VIEW foobar_view AS SELECT id FROM foobar; + GRANT SELECT, INSERT, UPDATE, DELETE ON foobar_view TO app_user; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + name: "Grant WITH GRANT OPTION on view", + roles: []string{"app_user"}, + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + CREATE VIEW foobar_view AS SELECT id FROM foobar; + `, + }, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + CREATE VIEW foobar_view AS SELECT id FROM foobar; + GRANT SELECT ON foobar_view TO app_user WITH GRANT OPTION; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + name: "Grant on non-public schema view", + roles: []string{"app_user"}, + oldSchemaDDL: []string{ + ` + CREATE SCHEMA app_schema; + CREATE TABLE app_schema.foobar(id INT); + CREATE VIEW app_schema.foobar_view AS SELECT id FROM app_schema.foobar; + `, + }, + newSchemaDDL: []string{ + ` + CREATE SCHEMA app_schema; + CREATE TABLE app_schema.foobar(id INT); + CREATE VIEW app_schema.foobar_view AS SELECT id FROM app_schema.foobar; + GRANT SELECT ON app_schema.foobar_view TO app_user; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{ + diff.MigrationHazardTypeAuthzUpdate, + }, + }, + { + name: "Grant on new view (no hazards since view is new)", + roles: []string{"app_user"}, + newSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + CREATE VIEW foobar_view AS SELECT id FROM foobar; + GRANT SELECT ON foobar_view TO app_user; + `, + }, + }, + { + name: "Drop view with privileges", + roles: []string{"app_user"}, + oldSchemaDDL: []string{ + ` + CREATE TABLE foobar(id INT); + CREATE VIEW foobar_view AS SELECT id FROM foobar; + GRANT SELECT ON foobar_view TO app_user; + `, + }, + newSchemaDDL: []string{ + `CREATE TABLE foobar(id INT);`, + }, + }, +} + +func TestViewPrivilegeCases(t *testing.T) { + runTestCases(t, viewPrivilegeAcceptanceTestCases) +} diff --git a/internal/queries/queries.sql b/internal/queries/queries.sql index d8397c8..8cab585 100644 --- a/internal/queries/queries.sql +++ b/internal/queries/queries.sql @@ -641,9 +641,9 @@ WITH parsed_acl AS ( n.nspname NOT IN ('pg_catalog', 'information_schema') AND n.nspname !~ '^pg_toast' AND n.nspname !~ '^pg_temp' - AND (c.relkind = 'r' OR c.relkind = 'p') + AND (c.relkind = 'r' OR c.relkind = 'p' OR c.relkind = 'v') AND c.relacl IS NOT null - -- Exclude tables owned by extensions + -- Exclude tables/views owned by extensions AND NOT EXISTS ( SELECT depend.objid FROM pg_catalog.pg_depend AS depend diff --git a/internal/queries/queries.sql.go b/internal/queries/queries.sql.go index 4315102..18159d6 100644 --- a/internal/queries/queries.sql.go +++ b/internal/queries/queries.sql.go @@ -1020,9 +1020,9 @@ WITH parsed_acl AS ( n.nspname NOT IN ('pg_catalog', 'information_schema') AND n.nspname !~ '^pg_toast' AND n.nspname !~ '^pg_temp' - AND (c.relkind = 'r' OR c.relkind = 'p') + AND (c.relkind = 'r' OR c.relkind = 'p' OR c.relkind = 'v') AND c.relacl IS NOT null - -- Exclude tables owned by extensions + -- Exclude tables/views owned by extensions AND NOT EXISTS ( SELECT depend.objid FROM pg_catalog.pg_depend AS depend diff --git a/internal/schema/schema.go b/internal/schema/schema.go index b3945d5..7b42f98 100644 --- a/internal/schema/schema.go +++ b/internal/schema/schema.go @@ -145,6 +145,9 @@ func normalizeView(v View) View { normTableDeps = append(normTableDeps, d) } v.TableDependencies = normTableDeps + + v.Privileges = sortSchemaObjectsByName(v.Privileges) + return v } @@ -531,6 +534,7 @@ type View struct { // TableDependencies is a list of tables the view depends on. TableDependencies []TableDependency + Privileges []TablePrivilege } type MaterializedView struct { @@ -1491,6 +1495,15 @@ func (s *schemaFetcher) fetchViews(ctx context.Context) ([]View, error) { return nil, fmt.Errorf("GetViews: %w", err) } + privileges, err := s.fetchPrivileges(ctx) + if err != nil { + return nil, fmt.Errorf("fetchPrivileges(): %w", err) + } + privilegesByView := make(map[string][]TablePrivilege) + for _, p := range privileges { + privilegesByView[p.table.GetFQEscapedName()] = append(privilegesByView[p.table.GetFQEscapedName()], p.privilege) + } + var views []View for _, v := range rawViews { options, err := relOptionsToMap(v.RelOptions) @@ -1503,12 +1516,14 @@ func (s *schemaFetcher) fetchViews(ctx context.Context) ([]View, error) { return nil, fmt.Errorf("parsing schema qualified names JSON: %w", err) } + schemaQualifiedName := buildNameFromUnescaped(v.ViewName, v.SchemaName) views = append(views, View{ - SchemaQualifiedName: buildNameFromUnescaped(v.ViewName, v.SchemaName), + SchemaQualifiedName: schemaQualifiedName, ViewDefinition: v.ViewDefinition, Options: options, TableDependencies: tableDependencies, + Privileges: privilegesByView[schemaQualifiedName.GetFQEscapedName()], }) } diff --git a/pkg/diff/view_sql_generator.go b/pkg/diff/view_sql_generator.go index d96f28d..44d8ba8 100644 --- a/pkg/diff/view_sql_generator.go +++ b/pkg/diff/view_sql_generator.go @@ -13,6 +13,7 @@ import ( type viewDiff struct { oldAndNew[schema.View] + privilegesDiff listDiff[schema.TablePrivilege, privilegeDiff] } func buildViewDiff( @@ -56,8 +57,24 @@ func buildViewDiff( } } + privilegesDiff, err := diffLists( + old.Privileges, + new.Privileges, + func(old, new schema.TablePrivilege, _, _ int) (privilegeDiff, bool, error) { + // Recreate the privilege if IsGrantable changes + recreate := old.IsGrantable != new.IsGrantable + return privilegeDiff{oldAndNew[schema.TablePrivilege]{old: old, new: new}}, recreate, nil + }, + ) + if err != nil { + return viewDiff{}, false, fmt.Errorf("diffing privileges: %w", err) + } + // Recreate if the view SQL generator cannot alter the view. - d := viewDiff{oldAndNew: oldAndNew[schema.View]{old: old, new: new}} + d := viewDiff{ + oldAndNew: oldAndNew[schema.View]{old: old, new: new}, + privilegesDiff: privilegesDiff, + } if _, err := newViewSQLVertexGenerator().Alter(d); err != nil { if errors.Is(err, ErrNotImplemented) { // The SQL generator cannot alter the view, so add and delete it. @@ -91,6 +108,22 @@ func (vsg *viewSQLGenerator) Add(v schema.View) (partialSQLGraph, error) { viewSb.WriteString(" AS\n") viewSb.WriteString(v.ViewDefinition) + stmts := []Statement{{ + DDL: viewSb.String(), + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + }} + + privilegeGenerator := &privilegeSQLVertexGenerator{tableName: v.SchemaQualifiedName} + for _, privilege := range v.Privileges { + addPrivilegeStmts, err := privilegeGenerator.Add(privilege) + if err != nil { + return partialSQLGraph{}, fmt.Errorf("generating add privilege statements for privilege %s: %w", privilege.GetName(), err) + } + // Remove hazards from statements since the view is brand new + stmts = append(stmts, stripMigrationHazards(addPrivilegeStmts...)...) + } + addVertexId := buildTableVertexId(v.SchemaQualifiedName, diffTypeAddAlter) var deps []dependency @@ -106,13 +139,9 @@ func (vsg *viewSQLGenerator) Add(v schema.View) (partialSQLGraph, error) { return partialSQLGraph{ vertices: []sqlVertex{{ - id: addVertexId, - priority: sqlPrioritySooner, - statements: []Statement{{ - DDL: viewSb.String(), - Timeout: statementTimeoutDefault, - LockTimeout: lockTimeoutDefault, - }}, + id: addVertexId, + priority: sqlPrioritySooner, + statements: stmts, }}, dependencies: deps, }, nil @@ -143,11 +172,23 @@ func (vsg *viewSQLGenerator) Delete(v schema.View) (partialSQLGraph, error) { } func (vsg *viewSQLGenerator) Alter(vd viewDiff) (partialSQLGraph, error) { - // In the initial MVP, we will not support altering. - if !cmp.Equal(vd.old, vd.new) { + // Compare old and new views ignoring the Privileges field, which is handled separately. + oldWithoutPrivileges := vd.old + oldWithoutPrivileges.Privileges = nil + newWithoutPrivileges := vd.new + newWithoutPrivileges.Privileges = nil + + if !cmp.Equal(oldWithoutPrivileges, newWithoutPrivileges) { return partialSQLGraph{}, ErrNotImplemented } - return partialSQLGraph{}, nil + + privilegeGenerator := newPrivilegeSQLVertexGenerator(vd.new.SchemaQualifiedName) + privilegesPartialGraph, err := generatePartialGraph(privilegeGenerator, vd.privilegesDiff) + if err != nil { + return partialSQLGraph{}, fmt.Errorf("resolving privilege sql: %w", err) + } + + return privilegesPartialGraph, nil } func buildViewVertexId(n schema.SchemaQualifiedName, d diffType) sqlVertexId {