diff --git a/cmd/XDC/consolecmd.go b/cmd/XDC/consolecmd.go index b612fbd803d2..79d31318f26e 100644 --- a/cmd/XDC/consolecmd.go +++ b/cmd/XDC/consolecmd.go @@ -18,6 +18,7 @@ package main import ( "fmt" + "net/url" "slices" "strings" @@ -78,10 +79,11 @@ func localConsole(ctx *cli.Context) error { // Attach to the newly started node and create the JavaScript console. client := stack.Attach() config := console.Config{ - DataDir: utils.MakeDataDir(ctx), - DocRoot: ctx.String(utils.JSpathFlag.Name), - Client: client, - Preload: utils.MakeConsolePreloads(ctx), + DataDir: utils.MakeDataDir(ctx), + DocRoot: ctx.String(utils.JSpathFlag.Name), + Client: client, + LocalTransport: true, + Preload: utils.MakeConsolePreloads(ctx), } console, err := console.New(config) if err != nil { @@ -134,15 +136,16 @@ func remoteConsole(ctx *cli.Context) error { endpoint = cfg.IPCEndpoint() } - client, err := dialRPC(endpoint) + client, localTransport, err := dialRPC(endpoint) if err != nil { utils.Fatalf("Unable to attach to remote XDC: %v", err) } config := console.Config{ - DataDir: utils.MakeDataDir(ctx), - DocRoot: ctx.String(utils.JSpathFlag.Name), - Client: client, - Preload: utils.MakeConsolePreloads(ctx), + DataDir: utils.MakeDataDir(ctx), + DocRoot: ctx.String(utils.JSpathFlag.Name), + Client: client, + LocalTransport: localTransport, + Preload: utils.MakeConsolePreloads(ctx), } console, err := console.New(config) if err != nil { @@ -177,13 +180,33 @@ XDC --exec "%s" console`, b.String()) // dialRPC returns a RPC client which connects to the given endpoint. // The check for empty endpoint implements the defaulting logic // for "XDC attach" and "XDC monitor" with no argument. -func dialRPC(endpoint string) (*rpc.Client, error) { +func dialRPC(endpoint string) (*rpc.Client, bool, error) { + endpoint, localTransport := resolveConsoleEndpoint(endpoint) + client, err := rpc.Dial(endpoint) + return client, localTransport, err +} + +func resolveConsoleEndpoint(endpoint string) (string, bool) { if endpoint == "" { - endpoint = node.DefaultIPCEndpoint(clientIdentifier) - } else if strings.HasPrefix(endpoint, "rpc:") || strings.HasPrefix(endpoint, "ipc:") { - // Backwards compatibility with geth < 1.5 which required - // these prefixes. - endpoint = endpoint[4:] + return node.DefaultIPCEndpoint(clientIdentifier), true + } + if strings.HasPrefix(endpoint, "ipc:") { + return endpoint[4:], true + } + endpoint = strings.TrimPrefix(endpoint, "rpc:") + if endpoint == "stdio" { + return endpoint, false + } + u, err := url.Parse(endpoint) + if err != nil { + return endpoint, false + } + switch u.Scheme { + case "http", "https", "ws", "wss", "stdio": + return endpoint, false + case "": + return endpoint, true + default: + return endpoint, false } - return rpc.Dial(endpoint) } diff --git a/cmd/XDC/consolecmd_test.go b/cmd/XDC/consolecmd_test.go index 7dc76c498c3e..79d6081ff135 100644 --- a/cmd/XDC/consolecmd_test.go +++ b/cmd/XDC/consolecmd_test.go @@ -168,6 +168,77 @@ To exit, press ctrl-d or type exit attach.ExpectExit() } +func TestResolveConsoleEndpoint(t *testing.T) { + tests := []struct { + name string + endpoint string + wantEndpoint string + wantLocal bool + }{ + {name: "default ipc endpoint", endpoint: "", wantEndpoint: "", wantLocal: true}, + {name: "plain ipc path", endpoint: "/tmp/XDC.ipc", wantEndpoint: "/tmp/XDC.ipc", wantLocal: true}, + {name: "legacy ipc prefix", endpoint: "ipc:/tmp/XDC.ipc", wantEndpoint: "/tmp/XDC.ipc", wantLocal: true}, + {name: "legacy rpc prefix", endpoint: "rpc:/tmp/XDC.ipc", wantEndpoint: "/tmp/XDC.ipc", wantLocal: true}, + {name: "windows drive path stays unsupported", endpoint: `C:\\Users\\tester\\XDC.ipc`, wantEndpoint: `C:\\Users\\tester\\XDC.ipc`, wantLocal: false}, + {name: "windows drive slash path stays unsupported", endpoint: "C:/Users/tester/XDC.ipc", wantEndpoint: "C:/Users/tester/XDC.ipc", wantLocal: false}, + {name: "legacy rpc windows drive path stays unsupported", endpoint: `rpc:C:\\Users\\tester\\XDC.ipc`, wantEndpoint: `C:\\Users\\tester\\XDC.ipc`, wantLocal: false}, + {name: "legacy rpc http endpoint", endpoint: "rpc:http://localhost:8545", wantEndpoint: "http://localhost:8545", wantLocal: false}, + {name: "legacy rpc ws endpoint", endpoint: "rpc:ws://localhost:8546", wantEndpoint: "ws://localhost:8546", wantLocal: false}, + {name: "stdio endpoint", endpoint: "stdio", wantEndpoint: "stdio", wantLocal: false}, + {name: "legacy rpc stdio endpoint", endpoint: "rpc:stdio", wantEndpoint: "stdio", wantLocal: false}, + {name: "http endpoint", endpoint: "http://localhost:8545", wantEndpoint: "http://localhost:8545", wantLocal: false}, + {name: "ws endpoint", endpoint: "ws://localhost:8546", wantEndpoint: "ws://localhost:8546", wantLocal: false}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + gotEndpoint, gotLocal := resolveConsoleEndpoint(test.endpoint) + if gotLocal != test.wantLocal { + t.Fatalf("unexpected local transport classification: got %v want %v", gotLocal, test.wantLocal) + } + if test.wantEndpoint == "" { + if !strings.HasSuffix(gotEndpoint, "XDC.ipc") { + t.Fatalf("expected default IPC endpoint, got %q", gotEndpoint) + } + return + } + if gotEndpoint != test.wantEndpoint { + t.Fatalf("unexpected resolved endpoint: got %q want %q", gotEndpoint, test.wantEndpoint) + } + }) + } +} + +func TestDialRPCRejectsWindowsDrivePaths(t *testing.T) { + tests := []struct { + name string + endpoint string + }{ + {name: "windows drive path", endpoint: `C:\\Users\\tester\\XDC.ipc`}, + {name: "windows drive slash path", endpoint: "C:/Users/tester/XDC.ipc"}, + {name: "legacy rpc windows drive path", endpoint: `rpc:C:\\Users\\tester\\XDC.ipc`}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + client, local, err := dialRPC(test.endpoint) + if client != nil { + client.Close() + t.Fatal("expected dialRPC to reject Windows drive-letter path") + } + if err == nil { + t.Fatal("expected dialRPC to fail for Windows drive-letter path") + } + if local { + t.Fatal("expected Windows drive-letter path to stay classified as non-local") + } + if !strings.Contains(err.Error(), `no known transport for URL scheme "c"`) { + t.Fatalf("unexpected dialRPC error: %v", err) + } + }) + } +} + // trulyRandInt generates a crypto random integer used by the console tests to // not clash network ports with other tests running cocurrently. func trulyRandInt(lo, hi int) int { diff --git a/console/console.go b/console/console.go index 3ed11e0eac9e..f32cb92b306a 100644 --- a/console/console.go +++ b/console/console.go @@ -56,26 +56,28 @@ const DefaultPrompt = "> " // Config is the collection of configurations to fine tune the behavior of the // JavaScript console. type Config struct { - DataDir string // Data directory to store the console history at - DocRoot string // Filesystem path from where to load JavaScript files from - Client *rpc.Client // RPC client to execute Ethereum requests through - Prompt string // Input prompt prefix string (defaults to DefaultPrompt) - Prompter prompt.UserPrompter // Input prompter to allow interactive user feedback (defaults to TerminalPrompter) - Printer io.Writer // Output writer to serialize any display strings to (defaults to os.Stdout) - Preload []string // Absolute paths to JavaScript files to preload + DataDir string // Data directory to store the console history at + DocRoot string // Filesystem path from where to load JavaScript files from + Client *rpc.Client // RPC client to execute Ethereum requests through + LocalTransport bool // Whether the console is attached over an in-process or IPC transport + Prompt string // Input prompt prefix string (defaults to DefaultPrompt) + Prompter prompt.UserPrompter // Input prompter to allow interactive user feedback (defaults to TerminalPrompter) + Printer io.Writer // Output writer to serialize any display strings to (defaults to os.Stdout) + Preload []string // Absolute paths to JavaScript files to preload } // Console is a JavaScript interpreted runtime environment. It is a fully fleged // JavaScript console attached to a running node via an external or in-process RPC // client. type Console struct { - client *rpc.Client // RPC client to execute Ethereum requests through - jsre *jsre.JSRE // JavaScript runtime environment running the interpreter - prompt string // Input prompt prefix string - prompter prompt.UserPrompter // Input prompter to allow interactive user feedback - histPath string // Absolute path to the console scrollback history - history []string // Scroll history maintained by the console - printer io.Writer // Output writer to serialize any display strings to + client *rpc.Client // RPC client to execute Ethereum requests through + jsre *jsre.JSRE // JavaScript runtime environment running the interpreter + localTransport bool // Whether the connected transport is in-process or IPC + prompt string // Input prompt prefix string + prompter prompt.UserPrompter // Input prompter to allow interactive user feedback + histPath string // Absolute path to the console scrollback history + history []string // Scroll history maintained by the console + printer io.Writer // Output writer to serialize any display strings to interactiveStopped chan struct{} stopInteractiveCh chan struct{} @@ -103,6 +105,7 @@ func New(config Config) (*Console, error) { console := &Console{ client: config.Client, jsre: jsre.New(config.DocRoot, config.Printer), + localTransport: config.LocalTransport, prompt: config.Prompt, prompter: config.Prompter, printer: config.Printer, @@ -235,9 +238,41 @@ func (c *Console) initExtensions() error { } } }) + if !c.localTransport { + c.hideUnavailableDebugMethods() + } return nil } +func (c *Console) hideUnavailableDebugMethods() { + c.jsre.Do(func(vm *goja.Runtime) { + if _, err := vm.RunString(` + (function() { + function hideMethod(target, name) { + if (!target) { + return; + } + Object.defineProperty(target, name, { + value: undefined, + writable: true, + configurable: true, + enumerable: false + }); + } + + if (typeof debug !== "undefined") { + hideMethod(debug, "setHead"); + } + if (typeof web3 !== "undefined" && web3 && web3.debug) { + hideMethod(web3.debug, "setHead"); + } + })(); + `); err != nil { + panic(err) + } + }) +} + // initAdmin creates additional admin APIs implemented by the bridge. func (c *Console) initAdmin(vm *goja.Runtime, bridge *bridge) { if admin := getObject(vm, "admin"); admin != nil { @@ -288,7 +323,23 @@ func (c *Console) AutoCompleteInput(line string, pos int) (string, []string, str start++ break } - return line[:start], c.jsre.CompleteKeywords(line[start:pos]), line[pos:] + return line[:start], c.filterCompletions(c.jsre.CompleteKeywords(line[start:pos])), line[pos:] +} + +func (c *Console) filterCompletions(completions []string) []string { + if c.localTransport { + return completions + } + filtered := completions[:0] + for _, completion := range completions { + switch completion { + case "debug.setHead", "debug.setHead(", "debug.setHead.", "web3.debug.setHead", "web3.debug.setHead(", "web3.debug.setHead.": + continue + default: + filtered = append(filtered, completion) + } + } + return filtered } // Welcome show summary of current Geth instance and some metadata about the diff --git a/console/console_test.go b/console/console_test.go index 6da06c17c96d..173f6d8b4114 100644 --- a/console/console_test.go +++ b/console/console_test.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "os" + "slices" "strings" "testing" "time" @@ -28,6 +29,7 @@ import ( "github.com/XinFinOrg/XDPoSChain/XDCx" "github.com/XinFinOrg/XDPoSChain/XDCxlending" "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/common/hexutil" "github.com/XinFinOrg/XDPoSChain/console/prompt" "github.com/XinFinOrg/XDPoSChain/core" "github.com/XinFinOrg/XDPoSChain/eth" @@ -35,6 +37,8 @@ import ( "github.com/XinFinOrg/XDPoSChain/internal/jsre" "github.com/XinFinOrg/XDPoSChain/miner" "github.com/XinFinOrg/XDPoSChain/node" + "github.com/XinFinOrg/XDPoSChain/rpc" + "github.com/dop251/goja" ) const ( @@ -123,12 +127,13 @@ func newTester(t *testing.T, confOverride func(*ethconfig.Config)) *tester { printer := new(bytes.Buffer) console, err := New(Config{ - DataDir: stack.DataDir(), - DocRoot: "testdata", - Client: client, - Prompter: prompter, - Printer: printer, - Preload: []string{"preload.js"}, + DataDir: stack.DataDir(), + DocRoot: "testdata", + Client: client, + LocalTransport: true, + Prompter: prompter, + Printer: printer, + Preload: []string{"preload.js"}, }) if err != nil { t.Fatalf("failed to create JavaScript console: %v", err) @@ -193,6 +198,167 @@ func TestEvaluate(t *testing.T) { } } +type debugPrintAndSetHeadRPC struct{} + +func (debugPrintAndSetHeadRPC) PrintBlock(uint64) (string, error) { + return "ok", nil +} + +func (debugPrintAndSetHeadRPC) SetHead(hexutil.Uint64) error { + return nil +} + +func TestConsoleHidesUnavailableDebugSetHead(t *testing.T) { + t.Run("hidden on remote transport", func(t *testing.T) { + console := newRPCConsole(t, debugPrintAndSetHeadRPC{}, false) + defer stopConsole(t, console) + assertDebugSetHeadVisible(t, console, false) + assertDebugSetHeadCompletion(t, console, false) + assertDebugSetHeadEnumerated(t, console, false) + assertDebugSetHeadPrettyPrinted(t, console, false) + }) + + t.Run("kept on local transport", func(t *testing.T) { + console := newRPCConsole(t, debugPrintAndSetHeadRPC{}, true) + defer stopConsole(t, console) + assertDebugSetHeadVisible(t, console, true) + assertDebugSetHeadCompletion(t, console, true) + assertDebugSetHeadEnumerated(t, console, true) + assertDebugSetHeadPrettyPrinted(t, console, true) + }) +} + +func newRPCConsole(t *testing.T, debugService interface{}, localTransport bool) *Console { + t.Helper() + + server := rpc.NewServer() + if err := server.RegisterName("debug", debugService); err != nil { + t.Fatalf("failed to register debug service: %v", err) + } + client := rpc.DialInProc(server) + t.Cleanup(func() { + client.Close() + }) + + console, err := New(Config{ + DataDir: t.TempDir(), + DocRoot: "testdata", + Client: client, + LocalTransport: localTransport, + Printer: new(bytes.Buffer), + }) + if err != nil { + t.Fatalf("failed to create console: %v", err) + } + return console +} + +func stopConsole(t *testing.T, console *Console) { + t.Helper() + if err := console.Stop(false); err != nil { + t.Fatalf("failed to stop console: %v", err) + } +} + +func assertDebugSetHeadVisible(t *testing.T, console *Console, want bool) { + t.Helper() + + console.jsre.Do(func(vm *goja.Runtime) { + debug := getObject(vm, "debug") + if debug == nil { + t.Fatal("debug object is not available") + } + got := !goja.IsUndefined(debug.Get("setHead")) + if got != want { + t.Fatalf("unexpected debug.setHead visibility: got %v want %v", got, want) + } + }) +} + +func assertDebugSetHeadCompletion(t *testing.T, console *Console, want bool) { + t.Helper() + + tests := []struct { + input string + want []string + }{ + {input: "debug.setH", want: []string{"debug.setHead", "debug.setHead(", "debug.setHead."}}, + {input: "debug.setHead", want: []string{"debug.setHead", "debug.setHead(", "debug.setHead."}}, + {input: "web3.debug.setH", want: []string{"web3.debug.setHead", "web3.debug.setHead(", "web3.debug.setHead."}}, + {input: "web3.debug.setHead", want: []string{"web3.debug.setHead", "web3.debug.setHead(", "web3.debug.setHead."}}, + } + for _, test := range tests { + _, completions, _ := console.AutoCompleteInput(test.input, len(test.input)) + got := false + for _, completion := range completions { + if slices.Contains(test.want, completion) { + got = true + break + } + } + if got != want { + t.Fatalf("unexpected debug.setHead completion visibility for %q: got %v want %v (completions=%v)", test.input, got, want, completions) + } + } +} + +func assertDebugSetHeadEnumerated(t *testing.T, console *Console, want bool) { + t.Helper() + + console.jsre.Do(func(vm *goja.Runtime) { + keys := getObject(vm, "Object") + if keys == nil { + t.Fatal("Object is not available") + } + keysFunc, ok := goja.AssertFunction(keys.Get("keys")) + if !ok { + t.Fatal("Object.keys is not available") + } + debug := getObject(vm, "debug") + if debug == nil { + t.Fatal("debug object is not available") + } + rv, err := keysFunc(goja.Undefined(), debug) + if err != nil { + t.Fatalf("Object.keys(debug) failed: %v", err) + } + got := false + switch keys := rv.Export().(type) { + case []any: + for _, key := range keys { + if key.(string) == "setHead" { + got = true + break + } + } + case []string: + if slices.Contains(keys, "setHead") { + got = true + } + default: + t.Fatalf("Object.keys(debug) returned unexpected type %T", keys) + } + if got != want { + t.Fatalf("unexpected debug.setHead enumeration visibility: got %v want %v", got, want) + } + }) +} + +func assertDebugSetHeadPrettyPrinted(t *testing.T, console *Console, want bool) { + t.Helper() + + printer, ok := console.printer.(*bytes.Buffer) + if !ok { + t.Fatal("console printer is not a bytes.Buffer") + } + printer.Reset() + console.Evaluate("debug") + got := strings.Contains(printer.String(), "setHead") + if got != want { + t.Fatalf("unexpected debug.setHead pretty-print visibility: got %v want %v output=%q", got, want, printer.String()) + } +} + // Tests that the console can be used in interactive mode. func TestInteractive(t *testing.T) { // Create a tester and run an interactive console in the background diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index cb6895740b78..92fa2e605ec2 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -2325,6 +2325,17 @@ func NewDebugAPI(b Backend) *DebugAPI { return &DebugAPI{b: b} } +// PrivateDebugAPI is the collection of debug APIs that are intended to be +// exposed only on local transports such as in-proc and IPC. +type PrivateDebugAPI struct { + b Backend +} + +// NewPrivateDebugAPI creates a new instance of PrivateDebugAPI. +func NewPrivateDebugAPI(b Backend) *PrivateDebugAPI { + return &PrivateDebugAPI{b: b} +} + // GetBlockRlp retrieves the RLP encoded for of a single block. func (api *DebugAPI) GetBlockRlp(ctx context.Context, number uint64) (string, error) { block, _ := api.b.BlockByNumber(ctx, rpc.BlockNumber(number)) @@ -2378,19 +2389,23 @@ func (api *DebugAPI) ChaindbCompact() error { return nil } -// SetHead rewinds the head of the blockchain to a previous block. -func (api *DebugAPI) SetHead(number hexutil.Uint64) error { - header := api.b.CurrentHeader() +func setHead(b Backend, number hexutil.Uint64) error { + header := b.CurrentHeader() if header == nil { return errors.New("current header is not available") } if header.Number.Uint64() <= uint64(number) { return errors.New("not allowed to rewind to a future block") } - api.b.SetHead(uint64(number)) + b.SetHead(uint64(number)) return nil } +// SetHead rewinds the head of the blockchain to a previous block. +func (api *PrivateDebugAPI) SetHead(number hexutil.Uint64) error { + return setHead(api.b, number) +} + // DbGet returns the raw value of a key stored in the database. func (api *DebugAPI) DbGet(key string) (hexutil.Bytes, error) { blob, err := common.ParseHexOrString(key) diff --git a/internal/ethapi/api_test.go b/internal/ethapi/api_test.go index 56188054da8d..efa7a00b6354 100644 --- a/internal/ethapi/api_test.go +++ b/internal/ethapi/api_test.go @@ -294,6 +294,45 @@ func TestRPCMarshalBlock(t *testing.T) { } } +func TestDebugSetHeadTransportExposure(t *testing.T) { + backend := newBackendMock() + apis := GetAPIs(backend, nil) + + openServer := rpc.NewServer() + localServer := rpc.NewServer() + for _, api := range apis { + if !api.Authenticated && !api.Local { + require.NoError(t, openServer.RegisterName(api.Namespace, api.Service)) + } + require.NoError(t, localServer.RegisterName(api.Namespace, api.Service)) + } + + openClient := rpc.DialInProc(openServer) + defer openClient.Close() + localClient := rpc.DialInProc(localServer) + defer localClient.Close() + + ctx := context.Background() + var block string + err := openClient.CallContext(ctx, &block, "debug_printBlock", uint64(0)) + if isMethodNotFound(err) { + t.Fatalf("expected debug_printBlock to remain exposed on open RPC, got %v", err) + } + + err = openClient.CallContext(ctx, nil, "debug_setHead", hexutil.Uint64(0)) + if !isMethodNotFound(err) { + t.Fatalf("expected debug_setHead to be hidden from open RPC, got %v", err) + } + + err = localClient.CallContext(ctx, nil, "debug_setHead", hexutil.Uint64(0)) + require.NoError(t, err) +} + +func isMethodNotFound(err error) bool { + rpcErr, ok := err.(rpc.Error) + return ok && rpcErr.ErrorCode() == -32601 +} + type testEngine struct{} func (testEngine) Author(header *types.Header) (common.Address, error) { return header.Coinbase, nil } diff --git a/internal/ethapi/backend.go b/internal/ethapi/backend.go index ad2c0ab68f08..fc7a3a0df6f6 100644 --- a/internal/ethapi/backend.go +++ b/internal/ethapi/backend.go @@ -134,6 +134,10 @@ func GetAPIs(apiBackend Backend, chainReader consensus.ChainReader) []rpc.API { }, { Namespace: "debug", Service: NewDebugAPI(apiBackend), + }, { + Namespace: "debug", + Service: NewPrivateDebugAPI(apiBackend), + Local: true, }, { Namespace: "eth", Service: NewEthereumAccountAPI(apiBackend.AccountManager()), diff --git a/internal/jsre/pretty.go b/internal/jsre/pretty.go index e5eb4680d1fa..1fd0cbf8480a 100644 --- a/internal/jsre/pretty.go +++ b/internal/jsre/pretty.go @@ -220,12 +220,27 @@ func (ctx ppctx) fields(obj *goja.Object) []string { } } } - iterOwnAndConstructorKeys(ctx.vm, obj, add) + iterEnumerableAndConstructorKeys(ctx.vm, obj, add) slices.Sort(vals) slices.Sort(methods) return append(vals, methods...) } +func iterEnumerableAndConstructorKeys(vm *goja.Runtime, obj *goja.Object, f func(string)) { + shadowed := make(map[string]bool) + iterOwnKeys(vm, obj, func(prop string) { + shadowed[prop] = true + }) + iterEnumerableKeys(vm, obj, f) + if cp := constructorPrototype(vm, obj); cp != nil { + iterEnumerableKeys(vm, cp, func(prop string) { + if !shadowed[prop] { + f(prop) + } + }) + } +} + func iterOwnAndConstructorKeys(vm *goja.Runtime, obj *goja.Object, f func(string)) { seen := make(map[string]bool) iterOwnKeys(vm, obj, func(prop string) { @@ -241,6 +256,31 @@ func iterOwnAndConstructorKeys(vm *goja.Runtime, obj *goja.Object, f func(string } } +func iterEnumerableKeys(vm *goja.Runtime, obj *goja.Object, f func(string)) { + Object := vm.Get("Object").ToObject(vm) + keys, isFunc := goja.AssertFunction(Object.Get("keys")) + if !isFunc { + panic(vm.ToValue("Object.keys isn't a function")) + } + rv, err := keys(goja.Null(), obj) + if err != nil { + panic(vm.ToValue(fmt.Sprintf("Error getting enumerable object properties: %v", err))) + } + gv := rv.Export() + switch gv := gv.(type) { + case []interface{}: + for _, v := range gv { + f(v.(string)) + } + case []string: + for _, v := range gv { + f(v) + } + default: + panic(fmt.Errorf("Object.keys returned unexpected type %T", gv)) + } +} + func iterOwnKeys(vm *goja.Runtime, obj *goja.Object, f func(string)) { Object := vm.Get("Object").ToObject(vm) getOwnPropertyNames, isFunc := goja.AssertFunction(Object.Get("getOwnPropertyNames")) diff --git a/node/api.go b/node/api.go index 94a619196ec4..2d8f205b4449 100644 --- a/node/api.go +++ b/node/api.go @@ -221,7 +221,8 @@ func (api *adminAPI) StartHTTP(host *string, port *int, cors *string, apis *stri if err := api.node.http.setListenAddr(*host, *port); err != nil { return false, err } - if err := api.node.http.enableRPC(api.node.rpcAPIs, config); err != nil { + openApis, _, _, _ := api.node.getAPIs() + if err := api.node.http.enableRPC(openApis, config); err != nil { return false, err } if err := api.node.http.start(); err != nil { @@ -295,7 +296,7 @@ func (api *adminAPI) StartWS(host *string, port *int, allowedOrigins *string, ap if err := server.setListenAddr(*host, *port); err != nil { return false, err } - openApis, _ := api.node.getAPIs() + openApis, _, _, _ := api.node.getAPIs() if err := server.enableWS(openApis, config); err != nil { return false, err } diff --git a/node/api_test.go b/node/api_test.go index 77c30df2b5ba..e50d4cbb46de 100644 --- a/node/api_test.go +++ b/node/api_test.go @@ -18,6 +18,7 @@ package node import ( "bytes" + "context" "io" "net" "net/http" @@ -299,6 +300,57 @@ func TestStartRPC(t *testing.T) { } } +func TestStartHTTPLocalAPIsRemainHidden(t *testing.T) { + config := Config{} + config.NoUSB = true + config.P2P.NoDiscovery = true + config.HTTPTimeouts = rpc.DefaultHTTPTimeouts + + stack, err := New(&config) + if err != nil { + t.Fatal("can't create node:", err) + } + defer stack.Close() + + stack.RegisterAPIs([]rpc.API{{ + Namespace: "debug", + Version: "1.0", + Service: helloRPC("hello debug"), + Public: true, + Local: true, + }}) + + if err := stack.Start(); err != nil { + t.Fatal("can't start node:", err) + } + + _, err = (&adminAPI{stack}).StartHTTP(sp("127.0.0.1"), ip(0), nil, sp("debug"), nil) + assert.NoError(t, err) + + localClient := stack.Attach() + defer localClient.Close() + + var out string + err = localClient.CallContext(context.Background(), &out, "debug_helloWorld") + assert.NoError(t, err) + assert.Equal(t, "hello debug", out) + + httpClient, err := rpc.DialHTTP(stack.HTTPEndpoint()) + if err != nil { + t.Fatalf("failed to dial HTTP endpoint: %v", err) + } + defer httpClient.Close() + + err = httpClient.CallContext(context.Background(), &out, "debug_helloWorld") + if err == nil { + t.Fatal("expected local-only API to stay hidden from HTTP RPC started via admin API") + } + rpcErr, ok := err.(rpc.Error) + if !ok || rpcErr.ErrorCode() != -32601 { + t.Fatalf("expected method-not-found for hidden local-only API, got %v", err) + } +} + // checkReachable checks if the TCP endpoint in rawurl is open. func checkReachable(rawurl string) bool { u, err := url.Parse(rawurl) diff --git a/node/node.go b/node/node.go index 6614c8def645..a057a3d9e37a 100644 --- a/node/node.go +++ b/node/node.go @@ -378,21 +378,20 @@ func (n *Node) obtainJWTSecret(cliParam string) ([]byte, error) { // startup. It's not meant to be called at any time afterwards as it makes certain // assumptions about the state of the node. func (n *Node) startRPC() error { - if err := n.startInProc(n.rpcAPIs); err != nil { + openAPIs, authAPIs, localAPIs, hasAuthenticated := n.getAPIs() + + if err := n.startInProc(localAPIs); err != nil { return err } // Configure IPC. if n.ipc.endpoint != "" { - if err := n.ipc.start(n.rpcAPIs); err != nil { + if err := n.ipc.start(localAPIs); err != nil { return err } } - var ( - servers []*httpServer - openAPIs, allAPIs = n.getAPIs() - ) + var servers []*httpServer rpcConfig := rpcEndpointConfig{ batchItemLimit: n.config.BatchRequestLimit, @@ -445,7 +444,7 @@ func (n *Node) startRPC() error { batchResponseSizeLimit: engineAPIBatchResponseSizeLimit, httpBodyLimit: engineAPIBodyLimit, } - err := server.enableRPC(allAPIs, httpConfig{ + err := server.enableRPC(authAPIs, httpConfig{ CorsAllowedOrigins: DefaultAuthCors, Vhosts: n.config.AuthVirtualHosts, Modules: DefaultAuthModules, @@ -462,7 +461,7 @@ func (n *Node) startRPC() error { if err := server.setListenAddr(n.config.AuthAddr, port); err != nil { return err } - if err := server.enableWS(allAPIs, wsConfig{ + if err := server.enableWS(authAPIs, wsConfig{ Modules: DefaultAuthModules, Origins: DefaultAuthOrigins, prefix: DefaultAuthPrefix, @@ -488,7 +487,7 @@ func (n *Node) startRPC() error { } } // Configure authenticated API - if len(openAPIs) != len(allAPIs) { + if hasAuthenticated { jwtSecret, err := n.obtainJWTSecret(n.config.JWTSecret) if err != nil { return err @@ -582,15 +581,25 @@ func (n *Node) RegisterAPIs(apis []rpc.API) { n.rpcAPIs = append(n.rpcAPIs, apis...) } -// getAPIs return two sets of APIs, both the ones that do not require -// authentication, and the complete set -func (n *Node) getAPIs() (unauthenticated, all []rpc.API) { +// getAPIs splits the registered APIs by transport. +// Open APIs are exposed on unauthenticated HTTP/WS. +// Auth APIs are exposed on authenticated HTTP/WS. +// Local APIs are exposed on in-process and IPC transports. +func (n *Node) getAPIs() (open, auth, local []rpc.API, hasAuthenticated bool) { for _, api := range n.rpcAPIs { + local = append(local, api) + if api.Local { + continue + } + if api.Authenticated { + hasAuthenticated = true + } + auth = append(auth, api) if !api.Authenticated { - unauthenticated = append(unauthenticated, api) + open = append(open, api) } } - return unauthenticated, n.rpcAPIs + return open, auth, local, hasAuthenticated } // RegisterHandler mounts a handler on the given path on the canonical HTTP server. diff --git a/node/node_auth_test.go b/node/node_auth_test.go index ac51e3c51a25..6cfc252cd33d 100644 --- a/node/node_auth_test.go +++ b/node/node_auth_test.go @@ -197,6 +197,69 @@ func TestAuthEndpoints(t *testing.T) { } } +func TestLocalAPIsDoNotStartAuthEndpoints(t *testing.T) { + conf := &Config{ + HTTPHost: "127.0.0.1", + HTTPPort: 0, + WSHost: "127.0.0.1", + WSPort: 0, + AuthAddr: "127.0.0.1", + AuthPort: 0, + + HTTPModules: []string{"debug"}, + WSModules: []string{"debug"}, + } + node, err := New(conf) + if err != nil { + t.Fatalf("could not create a new node: %v", err) + } + node.RegisterAPIs([]rpc.API{{ + Namespace: "debug", + Version: "1.0", + Service: helloRPC("hello debug"), + Public: true, + Local: true, + Authenticated: true, + }}) + if err := node.Start(); err != nil { + t.Fatalf("failed to start test node: %v", err) + } + defer node.Close() + + if node.httpAuth.httpHandler.Load() != nil { + t.Fatal("expected auth HTTP handler to remain disabled for local-only APIs") + } + if node.wsAuth.wsHandler.Load() != nil { + t.Fatal("expected auth WS handler to remain disabled for local-only APIs") + } + + client := node.Attach() + defer client.Close() + + var out string + if err := client.CallContext(context.Background(), &out, "debug_helloWorld"); err != nil { + t.Fatalf("failed to call local-only API over in-process RPC: %v", err) + } + if out != "hello debug" { + t.Fatalf("unexpected local-only API result: %q", out) + } + + httpClient, err := rpc.DialHTTP(node.HTTPEndpoint()) + if err != nil { + t.Fatalf("failed to dial HTTP endpoint: %v", err) + } + defer httpClient.Close() + + err = httpClient.CallContext(context.Background(), &out, "debug_helloWorld") + if err == nil { + t.Fatal("expected local-only API to stay hidden from HTTP RPC") + } + rpcErr, ok := err.(rpc.Error) + if !ok || rpcErr.ErrorCode() != -32601 { + t.Fatalf("expected method-not-found for hidden local-only API, got %v", err) + } +} + func noneAuth(secret [32]byte) rpc.HTTPAuth { return func(header http.Header) error { token := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{ diff --git a/rpc/types.go b/rpc/types.go index f0b7c2cfb0fd..670e0b458999 100644 --- a/rpc/types.go +++ b/rpc/types.go @@ -35,6 +35,7 @@ type API struct { Version string // deprecated - this field is no longer used, but retained for compatibility Service interface{} // receiver instance which holds the methods Public bool // deprecated - this field is no longer used, but retained for compatibility + Local bool // whether the api should only be available over local transports (in-process and IPC). Authenticated bool // whether the api should only be available behind authentication. }