Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 160 additions & 0 deletions internal/migration_acceptance_tests/view_privilege_cases_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package migration_acceptance_tests

import (
"testing"

"github.com/stripe/pg-schema-diff/pkg/diff"
)

var viewPrivilegeAcceptanceTestCases = []acceptanceTestCase{
{
name: "no-op: view with existing privilege",
roles: []string{"app_user"},
oldSchemaDDL: []string{
`
CREATE TABLE foobar(id INT);
CREATE VIEW foobar_view AS SELECT id FROM foobar;
GRANT SELECT ON foobar_view TO app_user;
`,
},
newSchemaDDL: []string{
`
CREATE TABLE foobar(id INT);
CREATE VIEW foobar_view AS SELECT id FROM foobar;
GRANT SELECT ON foobar_view TO app_user;
`,
},
expectEmptyPlan: true,
},
{
name: "Grant SELECT on view",
roles: []string{"app_user"},
oldSchemaDDL: []string{
`
CREATE TABLE foobar(id INT);
CREATE VIEW foobar_view AS SELECT id FROM foobar;
`,
},
newSchemaDDL: []string{
`
CREATE TABLE foobar(id INT);
CREATE VIEW foobar_view AS SELECT id FROM foobar;
GRANT SELECT ON foobar_view TO app_user;
`,
},
expectedHazardTypes: []diff.MigrationHazardType{
diff.MigrationHazardTypeAuthzUpdate,
},
},
{
name: "Revoke SELECT on view",
roles: []string{"app_user"},
oldSchemaDDL: []string{
`
CREATE TABLE foobar(id INT);
CREATE VIEW foobar_view AS SELECT id FROM foobar;
GRANT SELECT ON foobar_view TO app_user;
`,
},
newSchemaDDL: []string{
`
CREATE TABLE foobar(id INT);
CREATE VIEW foobar_view AS SELECT id FROM foobar;
`,
},
expectedHazardTypes: []diff.MigrationHazardType{
diff.MigrationHazardTypeAuthzUpdate,
},
},
{
name: "Grant multiple privileges on view",
roles: []string{"app_user"},
oldSchemaDDL: []string{
`
CREATE TABLE foobar(id INT);
CREATE VIEW foobar_view AS SELECT id FROM foobar;
`,
},
newSchemaDDL: []string{
`
CREATE TABLE foobar(id INT);
CREATE VIEW foobar_view AS SELECT id FROM foobar;
GRANT SELECT, INSERT, UPDATE, DELETE ON foobar_view TO app_user;
`,
},
expectedHazardTypes: []diff.MigrationHazardType{
diff.MigrationHazardTypeAuthzUpdate,
},
},
{
name: "Grant WITH GRANT OPTION on view",
roles: []string{"app_user"},
oldSchemaDDL: []string{
`
CREATE TABLE foobar(id INT);
CREATE VIEW foobar_view AS SELECT id FROM foobar;
`,
},
newSchemaDDL: []string{
`
CREATE TABLE foobar(id INT);
CREATE VIEW foobar_view AS SELECT id FROM foobar;
GRANT SELECT ON foobar_view TO app_user WITH GRANT OPTION;
`,
},
expectedHazardTypes: []diff.MigrationHazardType{
diff.MigrationHazardTypeAuthzUpdate,
},
},
{
name: "Grant on non-public schema view",
roles: []string{"app_user"},
oldSchemaDDL: []string{
`
CREATE SCHEMA app_schema;
CREATE TABLE app_schema.foobar(id INT);
CREATE VIEW app_schema.foobar_view AS SELECT id FROM app_schema.foobar;
`,
},
newSchemaDDL: []string{
`
CREATE SCHEMA app_schema;
CREATE TABLE app_schema.foobar(id INT);
CREATE VIEW app_schema.foobar_view AS SELECT id FROM app_schema.foobar;
GRANT SELECT ON app_schema.foobar_view TO app_user;
`,
},
expectedHazardTypes: []diff.MigrationHazardType{
diff.MigrationHazardTypeAuthzUpdate,
},
},
{
name: "Grant on new view (no hazards since view is new)",
roles: []string{"app_user"},
newSchemaDDL: []string{
`
CREATE TABLE foobar(id INT);
CREATE VIEW foobar_view AS SELECT id FROM foobar;
GRANT SELECT ON foobar_view TO app_user;
`,
},
},
{
name: "Drop view with privileges",
roles: []string{"app_user"},
oldSchemaDDL: []string{
`
CREATE TABLE foobar(id INT);
CREATE VIEW foobar_view AS SELECT id FROM foobar;
GRANT SELECT ON foobar_view TO app_user;
`,
},
newSchemaDDL: []string{
`CREATE TABLE foobar(id INT);`,
},
},
}

func TestViewPrivilegeCases(t *testing.T) {
runTestCases(t, viewPrivilegeAcceptanceTestCases)
}
4 changes: 2 additions & 2 deletions internal/queries/queries.sql
Original file line number Diff line number Diff line change
Expand Up @@ -641,9 +641,9 @@ WITH parsed_acl AS (
n.nspname NOT IN ('pg_catalog', 'information_schema')
AND n.nspname !~ '^pg_toast'
AND n.nspname !~ '^pg_temp'
AND (c.relkind = 'r' OR c.relkind = 'p')
AND (c.relkind = 'r' OR c.relkind = 'p' OR c.relkind = 'v')
AND c.relacl IS NOT null
-- Exclude tables owned by extensions
-- Exclude tables/views owned by extensions
AND NOT EXISTS (
SELECT depend.objid
FROM pg_catalog.pg_depend AS depend
Expand Down
4 changes: 2 additions & 2 deletions internal/queries/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 16 additions & 1 deletion internal/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ func normalizeView(v View) View {
normTableDeps = append(normTableDeps, d)
}
v.TableDependencies = normTableDeps

v.Privileges = sortSchemaObjectsByName(v.Privileges)

return v
}

Expand Down Expand Up @@ -531,6 +534,7 @@ type View struct {

// TableDependencies is a list of tables the view depends on.
TableDependencies []TableDependency
Privileges []TablePrivilege
}

type MaterializedView struct {
Expand Down Expand Up @@ -1491,6 +1495,15 @@ func (s *schemaFetcher) fetchViews(ctx context.Context) ([]View, error) {
return nil, fmt.Errorf("GetViews: %w", err)
}

privileges, err := s.fetchPrivileges(ctx)
if err != nil {
return nil, fmt.Errorf("fetchPrivileges(): %w", err)
}
privilegesByView := make(map[string][]TablePrivilege)
for _, p := range privileges {
privilegesByView[p.table.GetFQEscapedName()] = append(privilegesByView[p.table.GetFQEscapedName()], p.privilege)
}

var views []View
for _, v := range rawViews {
options, err := relOptionsToMap(v.RelOptions)
Expand All @@ -1503,12 +1516,14 @@ func (s *schemaFetcher) fetchViews(ctx context.Context) ([]View, error) {
return nil, fmt.Errorf("parsing schema qualified names JSON: %w", err)
}

schemaQualifiedName := buildNameFromUnescaped(v.ViewName, v.SchemaName)
views = append(views, View{
SchemaQualifiedName: buildNameFromUnescaped(v.ViewName, v.SchemaName),
SchemaQualifiedName: schemaQualifiedName,
ViewDefinition: v.ViewDefinition,
Options: options,

TableDependencies: tableDependencies,
Privileges: privilegesByView[schemaQualifiedName.GetFQEscapedName()],
})
}

Expand Down
63 changes: 52 additions & 11 deletions pkg/diff/view_sql_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

type viewDiff struct {
oldAndNew[schema.View]
privilegesDiff listDiff[schema.TablePrivilege, privilegeDiff]
}

func buildViewDiff(
Expand Down Expand Up @@ -56,8 +57,24 @@ func buildViewDiff(
}
}

privilegesDiff, err := diffLists(
old.Privileges,
new.Privileges,
func(old, new schema.TablePrivilege, _, _ int) (privilegeDiff, bool, error) {
// Recreate the privilege if IsGrantable changes
recreate := old.IsGrantable != new.IsGrantable
return privilegeDiff{oldAndNew[schema.TablePrivilege]{old: old, new: new}}, recreate, nil
},
)
if err != nil {
return viewDiff{}, false, fmt.Errorf("diffing privileges: %w", err)
}

// Recreate if the view SQL generator cannot alter the view.
d := viewDiff{oldAndNew: oldAndNew[schema.View]{old: old, new: new}}
d := viewDiff{
oldAndNew: oldAndNew[schema.View]{old: old, new: new},
privilegesDiff: privilegesDiff,
}
if _, err := newViewSQLVertexGenerator().Alter(d); err != nil {
if errors.Is(err, ErrNotImplemented) {
// The SQL generator cannot alter the view, so add and delete it.
Expand Down Expand Up @@ -91,6 +108,22 @@ func (vsg *viewSQLGenerator) Add(v schema.View) (partialSQLGraph, error) {
viewSb.WriteString(" AS\n")
viewSb.WriteString(v.ViewDefinition)

stmts := []Statement{{
DDL: viewSb.String(),
Timeout: statementTimeoutDefault,
LockTimeout: lockTimeoutDefault,
}}

privilegeGenerator := &privilegeSQLVertexGenerator{tableName: v.SchemaQualifiedName}
for _, privilege := range v.Privileges {
addPrivilegeStmts, err := privilegeGenerator.Add(privilege)
if err != nil {
return partialSQLGraph{}, fmt.Errorf("generating add privilege statements for privilege %s: %w", privilege.GetName(), err)
}
// Remove hazards from statements since the view is brand new
stmts = append(stmts, stripMigrationHazards(addPrivilegeStmts...)...)
}

addVertexId := buildTableVertexId(v.SchemaQualifiedName, diffTypeAddAlter)

var deps []dependency
Expand All @@ -106,13 +139,9 @@ func (vsg *viewSQLGenerator) Add(v schema.View) (partialSQLGraph, error) {

return partialSQLGraph{
vertices: []sqlVertex{{
id: addVertexId,
priority: sqlPrioritySooner,
statements: []Statement{{
DDL: viewSb.String(),
Timeout: statementTimeoutDefault,
LockTimeout: lockTimeoutDefault,
}},
id: addVertexId,
priority: sqlPrioritySooner,
statements: stmts,
}},
dependencies: deps,
}, nil
Expand Down Expand Up @@ -143,11 +172,23 @@ func (vsg *viewSQLGenerator) Delete(v schema.View) (partialSQLGraph, error) {
}

func (vsg *viewSQLGenerator) Alter(vd viewDiff) (partialSQLGraph, error) {
// In the initial MVP, we will not support altering.
if !cmp.Equal(vd.old, vd.new) {
// Compare old and new views ignoring the Privileges field, which is handled separately.
oldWithoutPrivileges := vd.old
oldWithoutPrivileges.Privileges = nil
newWithoutPrivileges := vd.new
newWithoutPrivileges.Privileges = nil

if !cmp.Equal(oldWithoutPrivileges, newWithoutPrivileges) {
return partialSQLGraph{}, ErrNotImplemented
}
return partialSQLGraph{}, nil

privilegeGenerator := newPrivilegeSQLVertexGenerator(vd.new.SchemaQualifiedName)
privilegesPartialGraph, err := generatePartialGraph(privilegeGenerator, vd.privilegesDiff)
if err != nil {
return partialSQLGraph{}, fmt.Errorf("resolving privilege sql: %w", err)
}

return privilegesPartialGraph, nil
}

func buildViewVertexId(n schema.SchemaQualifiedName, d diffType) sqlVertexId {
Expand Down