Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions cmd/github-mcp-server/generate_docs.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strings"

"github.com/github/github-mcp-server/pkg/github"
"github.com/github/github-mcp-server/pkg/lockdown"
"github.com/github/github-mcp-server/pkg/raw"
"github.com/github/github-mcp-server/pkg/toolsets"
"github.com/github/github-mcp-server/pkg/translations"
Expand Down Expand Up @@ -64,7 +65,8 @@ func generateReadmeDocs(readmePath string) error {
t, _ := translations.TranslationHelper()

// Create toolset group with mock clients
tsg := github.DefaultToolsetGroup(false, mockGetClient, mockGetGQLClient, mockGetRawClient, t, 5000, github.FeatureFlags{})
repoAccessCache := lockdown.NewRepoAccessCache(nil)
tsg := github.DefaultToolsetGroup(false, mockGetClient, mockGetGQLClient, mockGetRawClient, t, 5000, github.FeatureFlags{}, repoAccessCache)

// Generate toolsets documentation
toolsetsDoc := generateToolsetsDoc(tsg)
Expand Down Expand Up @@ -302,7 +304,8 @@ func generateRemoteToolsetsDoc() string {
t, _ := translations.TranslationHelper()

// Create toolset group with mock clients
tsg := github.DefaultToolsetGroup(false, mockGetClient, mockGetGQLClient, mockGetRawClient, t, 5000, github.FeatureFlags{})
repoAccessCache := lockdown.NewRepoAccessCache(nil)
tsg := github.DefaultToolsetGroup(false, mockGetClient, mockGetGQLClient, mockGetRawClient, t, 5000, github.FeatureFlags{}, repoAccessCache)

// Generate table header
buf.WriteString("| Name | Description | API URL | 1-Click Install (VS Code) | Read-only Link | 1-Click Read-only Install (VS Code) |\n")
Expand Down
5 changes: 5 additions & 0 deletions cmd/github-mcp-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"os"
"strings"
"time"

"github.com/github/github-mcp-server/internal/ghmcp"
"github.com/github/github-mcp-server/pkg/github"
Expand Down Expand Up @@ -50,6 +51,7 @@ var (
enabledToolsets = []string{github.ToolsetMetadataDefault.ID}
}

ttl := viper.GetDuration("repo-access-cache-ttl")
stdioServerConfig := ghmcp.StdioServerConfig{
Version: version,
Host: viper.GetString("host"),
Expand All @@ -62,6 +64,7 @@ var (
LogFilePath: viper.GetString("log-file"),
ContentWindowSize: viper.GetInt("content-window-size"),
LockdownMode: viper.GetBool("lockdown-mode"),
RepoAccessCacheTTL: &ttl,
}
return ghmcp.RunStdioServer(stdioServerConfig)
},
Expand All @@ -84,6 +87,7 @@ func init() {
rootCmd.PersistentFlags().String("gh-host", "", "Specify the GitHub hostname (for GitHub Enterprise etc.)")
rootCmd.PersistentFlags().Int("content-window-size", 5000, "Specify the content window size")
rootCmd.PersistentFlags().Bool("lockdown-mode", false, "Enable lockdown mode")
rootCmd.PersistentFlags().Duration("repo-access-cache-ttl", 5*time.Minute, "Override the repo access cache TTL (e.g. 1m, 0s to disable)")

// Bind flag to viper
_ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets"))
Expand All @@ -95,6 +99,7 @@ func init() {
_ = viper.BindPFlag("host", rootCmd.PersistentFlags().Lookup("gh-host"))
_ = viper.BindPFlag("content-window-size", rootCmd.PersistentFlags().Lookup("content-window-size"))
_ = viper.BindPFlag("lockdown-mode", rootCmd.PersistentFlags().Lookup("lockdown-mode"))
_ = viper.BindPFlag("repo-access-cache-ttl", rootCmd.PersistentFlags().Lookup("repo-access-cache-ttl"))

// Add subcommands
rootCmd.AddCommand(stdioCmd)
Expand Down
51 changes: 34 additions & 17 deletions internal/ghmcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/github/github-mcp-server/pkg/errors"
"github.com/github/github-mcp-server/pkg/github"
"github.com/github/github-mcp-server/pkg/lockdown"
mcplog "github.com/github/github-mcp-server/pkg/log"
"github.com/github/github-mcp-server/pkg/raw"
"github.com/github/github-mcp-server/pkg/translations"
Expand Down Expand Up @@ -54,6 +55,9 @@ type MCPServerConfig struct {

// LockdownMode indicates if we should enable lockdown mode
LockdownMode bool

// RepoAccessTTL overrides the default TTL for repository access cache entries.
RepoAccessTTL *time.Duration
}

const stdioServerLogPrefix = "stdioserver"
Expand All @@ -80,6 +84,14 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) {
},
} // We're going to wrap the Transport later in beforeInit
gqlClient := githubv4.NewEnterpriseClient(apiHost.graphqlURL.String(), gqlHTTPClient)
repoAccessOpts := []lockdown.RepoAccessOption{}
if cfg.RepoAccessTTL != nil {
repoAccessOpts = append(repoAccessOpts, lockdown.WithTTL(*cfg.RepoAccessTTL))
}
var repoAccessCache *lockdown.RepoAccessCache
if cfg.LockdownMode {
repoAccessCache = lockdown.NewRepoAccessCache(gqlClient, repoAccessOpts...)
}

// When a client send an initialize request, update the user agent to include the client info.
beforeInit := func(_ context.Context, _ any, message *mcp.InitializeRequest) {
Expand Down Expand Up @@ -165,6 +177,7 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) {
cfg.Translator,
cfg.ContentWindowSize,
github.FeatureFlags{LockdownMode: cfg.LockdownMode},
repoAccessCache,
)
err = tsg.EnableToolsets(enabledToolsets, nil)

Expand Down Expand Up @@ -219,6 +232,9 @@ type StdioServerConfig struct {

// LockdownMode indicates if we should enable lockdown mode
LockdownMode bool

// RepoAccessCacheTTL overrides the default TTL for repository access cache entries.
RepoAccessCacheTTL *time.Duration
}

// RunStdioServer is not concurrent safe.
Expand All @@ -229,23 +245,6 @@ func RunStdioServer(cfg StdioServerConfig) error {

t, dumpTranslations := translations.TranslationHelper()

ghServer, err := NewMCPServer(MCPServerConfig{
Version: cfg.Version,
Host: cfg.Host,
Token: cfg.Token,
EnabledToolsets: cfg.EnabledToolsets,
DynamicToolsets: cfg.DynamicToolsets,
ReadOnly: cfg.ReadOnly,
Translator: t,
ContentWindowSize: cfg.ContentWindowSize,
LockdownMode: cfg.LockdownMode,
})
if err != nil {
return fmt.Errorf("failed to create MCP server: %w", err)
}

stdioServer := server.NewStdioServer(ghServer)

var slogHandler slog.Handler
var logOutput io.Writer
if cfg.LogFilePath != "" {
Expand All @@ -262,6 +261,24 @@ func RunStdioServer(cfg StdioServerConfig) error {
logger := slog.New(slogHandler)
logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "dynamicToolsets", cfg.DynamicToolsets, "readOnly", cfg.ReadOnly, "lockdownEnabled", cfg.LockdownMode)
stdLogger := log.New(logOutput, stdioServerLogPrefix, 0)

ghServer, err := NewMCPServer(MCPServerConfig{
Version: cfg.Version,
Host: cfg.Host,
Token: cfg.Token,
EnabledToolsets: cfg.EnabledToolsets,
DynamicToolsets: cfg.DynamicToolsets,
ReadOnly: cfg.ReadOnly,
Translator: t,
ContentWindowSize: cfg.ContentWindowSize,
LockdownMode: cfg.LockdownMode,
RepoAccessTTL: cfg.RepoAccessCacheTTL,
})
if err != nil {
return fmt.Errorf("failed to create MCP server: %w", err)
}

stdioServer := server.NewStdioServer(ghServer)
stdioServer.SetErrorLogger(stdLogger)

if cfg.ExportTranslations {
Expand Down
85 changes: 73 additions & 12 deletions pkg/github/issues.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ func fragmentToIssue(fragment IssueFragment) *github.Issue {
}

// GetIssue creates a tool to get details of a specific issue in a GitHub repository.
func IssueRead(getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc, flags FeatureFlags) (tool mcp.Tool, handler server.ToolHandlerFunc) {
func IssueRead(getClient GetClientFn, getGQLClient GetGQLClientFn, cache *lockdown.RepoAccessCache, t translations.TranslationHelperFunc, flags FeatureFlags) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("issue_read",
mcp.WithDescription(t("TOOL_ISSUE_READ_DESCRIPTION", "Get information about a specific issue in a GitHub repository.")),
mcp.WithToolAnnotation(mcp.ToolAnnotation{
Expand Down Expand Up @@ -297,20 +297,20 @@ Options are:

switch method {
case "get":
return GetIssue(ctx, client, gqlClient, owner, repo, issueNumber, flags)
return GetIssue(ctx, client, cache, owner, repo, issueNumber, flags)
case "get_comments":
return GetIssueComments(ctx, client, owner, repo, issueNumber, pagination, flags)
return GetIssueComments(ctx, client, cache, owner, repo, issueNumber, pagination, flags)
case "get_sub_issues":
return GetSubIssues(ctx, client, owner, repo, issueNumber, pagination, flags)
return GetSubIssues(ctx, client, cache, owner, repo, issueNumber, pagination, flags)
case "get_labels":
return GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber, flags)
return GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber)
default:
return mcp.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil
}
}
}

func GetIssue(ctx context.Context, client *github.Client, gqlClient *githubv4.Client, owner string, repo string, issueNumber int, flags FeatureFlags) (*mcp.CallToolResult, error) {
func GetIssue(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner string, repo string, issueNumber int, flags FeatureFlags) (*mcp.CallToolResult, error) {
issue, resp, err := client.Issues.Get(ctx, owner, repo, issueNumber)
if err != nil {
return nil, fmt.Errorf("failed to get issue: %w", err)
Expand All @@ -326,12 +326,16 @@ func GetIssue(ctx context.Context, client *github.Client, gqlClient *githubv4.Cl
}

if flags.LockdownMode {
if issue.User != nil {
shouldRemoveContent, err := lockdown.ShouldRemoveContent(ctx, gqlClient, *issue.User.Login, owner, repo)
if cache == nil {
return nil, fmt.Errorf("lockdown cache is not configured")
}
login := issue.GetUser().GetLogin()
if login != "" {
isPrivate, hasPushAccess, err := cache.GetRepoAccessInfo(ctx, login, owner, repo)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil
}
if shouldRemoveContent {
if !isPrivate && !hasPushAccess {
return mcp.NewToolResultError("access to issue details is restricted by lockdown mode"), nil
}
}
Expand All @@ -355,7 +359,7 @@ func GetIssue(ctx context.Context, client *github.Client, gqlClient *githubv4.Cl
return mcp.NewToolResultText(string(r)), nil
}

func GetIssueComments(ctx context.Context, client *github.Client, owner string, repo string, issueNumber int, pagination PaginationParams, _ FeatureFlags) (*mcp.CallToolResult, error) {
func GetIssueComments(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner string, repo string, issueNumber int, pagination PaginationParams, flags FeatureFlags) (*mcp.CallToolResult, error) {
opts := &github.IssueListCommentsOptions{
ListOptions: github.ListOptions{
Page: pagination.Page,
Expand All @@ -376,6 +380,34 @@ func GetIssueComments(ctx context.Context, client *github.Client, owner string,
}
return mcp.NewToolResultError(fmt.Sprintf("failed to get issue comments: %s", string(body))), nil
}
if flags.LockdownMode {
if cache == nil {
return nil, fmt.Errorf("lockdown cache is not configured")
}
filteredComments := make([]*github.IssueComment, 0, len(comments))
for _, comment := range comments {
user := comment.User
if user == nil {
continue
}
login := user.GetLogin()
if login == "" {
continue
}
isPrivate, hasPushAccess, err := cache.GetRepoAccessInfo(ctx, login, owner, repo)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil
}
// Do not filter content for private repositories
if isPrivate {
break
}
if hasPushAccess {
filteredComments = append(filteredComments, comment)
}
}
comments = filteredComments
}

r, err := json.Marshal(comments)
if err != nil {
Expand All @@ -385,7 +417,7 @@ func GetIssueComments(ctx context.Context, client *github.Client, owner string,
return mcp.NewToolResultText(string(r)), nil
}

func GetSubIssues(ctx context.Context, client *github.Client, owner string, repo string, issueNumber int, pagination PaginationParams, _ FeatureFlags) (*mcp.CallToolResult, error) {
func GetSubIssues(ctx context.Context, client *github.Client, cache *lockdown.RepoAccessCache, owner string, repo string, issueNumber int, pagination PaginationParams, featureFlags FeatureFlags) (*mcp.CallToolResult, error) {
opts := &github.IssueListOptions{
ListOptions: github.ListOptions{
Page: pagination.Page,
Expand All @@ -412,6 +444,35 @@ func GetSubIssues(ctx context.Context, client *github.Client, owner string, repo
return mcp.NewToolResultError(fmt.Sprintf("failed to list sub-issues: %s", string(body))), nil
}

if featureFlags.LockdownMode {
if cache == nil {
return nil, fmt.Errorf("lockdown cache is not configured")
}
filteredSubIssues := make([]*github.SubIssue, 0, len(subIssues))
for _, subIssue := range subIssues {
user := subIssue.User
if user == nil {
continue
}
login := user.GetLogin()
if login == "" {
continue
}
isPrivate, hasPushAccess, err := cache.GetRepoAccessInfo(ctx, login, owner, repo)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to check lockdown mode: %v", err)), nil
}
// Repo is private, do not filter content
if isPrivate {
break
}
if hasPushAccess {
filteredSubIssues = append(filteredSubIssues, subIssue)
}
}
subIssues = filteredSubIssues
}

r, err := json.Marshal(subIssues)
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
Expand All @@ -420,7 +481,7 @@ func GetSubIssues(ctx context.Context, client *github.Client, owner string, repo
return mcp.NewToolResultText(string(r)), nil
}

func GetIssueLabels(ctx context.Context, client *githubv4.Client, owner string, repo string, issueNumber int, _ FeatureFlags) (*mcp.CallToolResult, error) {
func GetIssueLabels(ctx context.Context, client *githubv4.Client, owner string, repo string, issueNumber int) (*mcp.CallToolResult, error) {
// Get current labels on the issue using GraphQL
var query struct {
Repository struct {
Expand Down
16 changes: 8 additions & 8 deletions pkg/github/issues_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func Test_GetIssue(t *testing.T) {
// Verify tool definition once
mockClient := github.NewClient(nil)
defaultGQLClient := githubv4.NewClient(nil)
tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(defaultGQLClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false}))
tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(defaultGQLClient), stubRepoAccessCache(defaultGQLClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false}))
require.NoError(t, toolsnaps.Test(tool.Name, tool))

assert.Equal(t, "issue_read", tool.Name)
Expand Down Expand Up @@ -212,7 +212,7 @@ func Test_GetIssue(t *testing.T) {
}

flags := stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled})
_, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, flags)
_, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient), translations.NullTranslationHelper, flags)

request := createMCPRequest(tc.requestArgs)
result, err := handler(context.Background(), request)
Expand Down Expand Up @@ -1710,7 +1710,7 @@ func Test_GetIssueComments(t *testing.T) {
// Verify tool definition once
mockClient := github.NewClient(nil)
gqlClient := githubv4.NewClient(nil)
tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false}))
tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false}))
require.NoError(t, toolsnaps.Test(tool.Name, tool))

assert.Equal(t, "issue_read", tool.Name)
Expand Down Expand Up @@ -1816,7 +1816,7 @@ func Test_GetIssueComments(t *testing.T) {
// Setup client with mock
client := github.NewClient(tc.mockedClient)
gqlClient := githubv4.NewClient(nil)
_, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false}))
_, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false}))

// Create call request
request := createMCPRequest(tc.requestArgs)
Expand Down Expand Up @@ -1853,7 +1853,7 @@ func Test_GetIssueLabels(t *testing.T) {
// Verify tool definition
mockGQClient := githubv4.NewClient(nil)
mockClient := github.NewClient(nil)
tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(mockGQClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false}))
tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(mockGQClient), stubRepoAccessCache(mockGQClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false}))
require.NoError(t, toolsnaps.Test(tool.Name, tool))

assert.Equal(t, "issue_read", tool.Name)
Expand Down Expand Up @@ -1928,7 +1928,7 @@ func Test_GetIssueLabels(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
gqlClient := githubv4.NewClient(tc.mockedClient)
client := github.NewClient(nil)
_, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false}))
_, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false}))

request := createMCPRequest(tc.requestArgs)
result, err := handler(context.Background(), request)
Expand Down Expand Up @@ -2619,7 +2619,7 @@ func Test_GetSubIssues(t *testing.T) {
// Verify tool definition once
mockClient := github.NewClient(nil)
gqlClient := githubv4.NewClient(nil)
tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false}))
tool, _ := IssueRead(stubGetClientFn(mockClient), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false}))
require.NoError(t, toolsnaps.Test(tool.Name, tool))

assert.Equal(t, "issue_read", tool.Name)
Expand Down Expand Up @@ -2816,7 +2816,7 @@ func Test_GetSubIssues(t *testing.T) {
// Setup client with mock
client := github.NewClient(tc.mockedClient)
gqlClient := githubv4.NewClient(nil)
_, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false}))
_, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false}))

// Create call request
request := createMCPRequest(tc.requestArgs)
Expand Down
Loading
Loading