Skip to content

Commit 6b262de

Browse files
committed
mongo: pr review
1 parent bada7e6 commit 6b262de

File tree

2 files changed

+16
-22
lines changed

2 files changed

+16
-22
lines changed

flow/shared/mongo/commands.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ type BuildInfo struct {
1212
Version string `bson:"version"`
1313
}
1414

15-
func GetBuildInfo(ctx context.Context, client *mongo.Client) (*BuildInfo, error) {
15+
func GetBuildInfo(ctx context.Context, client *mongo.Client) (BuildInfo, error) {
1616
return runCommand[BuildInfo](ctx, client, "buildInfo")
1717
}
1818

@@ -21,7 +21,7 @@ type ReplSetGetStatus struct {
2121
MyState int `bson:"myState"`
2222
}
2323

24-
func GetReplSetGetStatus(ctx context.Context, client *mongo.Client) (*ReplSetGetStatus, error) {
24+
func GetReplSetGetStatus(ctx context.Context, client *mongo.Client) (ReplSetGetStatus, error) {
2525
return runCommand[ReplSetGetStatus](ctx, client, "replSetGetStatus")
2626
}
2727

@@ -38,7 +38,7 @@ type ServerStatus struct {
3838
OplogTruncation OplogTruncation `bson:"oplogTruncation"`
3939
}
4040

41-
func GetServerStatus(ctx context.Context, client *mongo.Client) (*ServerStatus, error) {
41+
func GetServerStatus(ctx context.Context, client *mongo.Client) (ServerStatus, error) {
4242
return runCommand[ServerStatus](ctx, client, "serverStatus")
4343
}
4444

@@ -55,7 +55,7 @@ type Role struct {
5555
DB string `bson:"db"`
5656
}
5757

58-
func GetConnectionStatus(ctx context.Context, client *mongo.Client) (*ConnectionStatus, error) {
58+
func GetConnectionStatus(ctx context.Context, client *mongo.Client) (ConnectionStatus, error) {
5959
return runCommand[ConnectionStatus](ctx, client, "connectionStatus")
6060
}
6161

@@ -64,21 +64,21 @@ type HelloResponse struct {
6464
Hosts []string `bson:"hosts,omitempty"`
6565
}
6666

67-
func GetHelloResponse(ctx context.Context, client *mongo.Client) (*HelloResponse, error) {
67+
func GetHelloResponse(ctx context.Context, client *mongo.Client) (HelloResponse, error) {
6868
return runCommand[HelloResponse](ctx, client, "hello")
6969
}
7070

71-
func runCommand[T any](ctx context.Context, client *mongo.Client, command string) (*T, error) {
71+
func runCommand[T any](ctx context.Context, client *mongo.Client, command string) (T, error) {
72+
var result T
7273
singleResult := client.Database("admin").RunCommand(ctx, bson.D{
7374
bson.E{Key: command, Value: 1},
7475
})
7576
if singleResult.Err() != nil {
76-
return nil, fmt.Errorf("'%s' failed: %v", command, singleResult.Err())
77+
return result, fmt.Errorf("'%s' failed: %v", command, singleResult.Err())
7778
}
7879

79-
var result T
8080
if err := singleResult.Decode(&result); err != nil {
81-
return nil, fmt.Errorf("'%s' failed: %v", command, err)
81+
return result, fmt.Errorf("'%s' failed: %v", command, err)
8282
}
83-
return &result, nil
83+
return result, nil
8484
}

flow/shared/mongo/validation.go

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"slices"
78

89
"go.mongodb.org/mongo-driver/v2/mongo"
910
)
@@ -16,6 +17,8 @@ const (
1617
ShardedCluster = "ShardedCluster"
1718
)
1819

20+
var RequiredRoles = [...]string{"readAnyDatabase", "clusterMonitor"}
21+
1922
func ValidateServerCompatibility(ctx context.Context, client *mongo.Client) error {
2023
buildInfo, err := GetBuildInfo(ctx, client)
2124
if err != nil {
@@ -54,24 +57,15 @@ func ValidateServerCompatibility(ctx context.Context, client *mongo.Client) erro
5457
}
5558

5659
func ValidateUserRoles(ctx context.Context, client *mongo.Client) error {
57-
RequiredRoles := []string{"readAnyDatabase", "clusterMonitor"}
58-
5960
connectionStatus, err := GetConnectionStatus(ctx, client)
6061
if err != nil {
6162
return err
6263
}
6364

64-
hasRole := func(roles []Role, targetRole string) bool {
65-
for _, role := range roles {
66-
if role.Role == targetRole {
67-
return true
68-
}
69-
}
70-
return false
71-
}
72-
7365
for _, requiredRole := range RequiredRoles {
74-
if !hasRole(connectionStatus.AuthInfo.AuthenticatedUserRoles, requiredRole) {
66+
if !slices.ContainsFunc(connectionStatus.AuthInfo.AuthenticatedUserRoles, func(r Role) bool {
67+
return r.Role == requiredRole
68+
}) {
7569
return fmt.Errorf("missing required role: %s", requiredRole)
7670
}
7771
}

0 commit comments

Comments
 (0)