diff --git a/crates/bashkit-cli/src/mcp.rs b/crates/bashkit-cli/src/mcp.rs index f465c617..b52c98cc 100644 --- a/crates/bashkit-cli/src/mcp.rs +++ b/crates/bashkit-cli/src/mcp.rs @@ -709,7 +709,7 @@ mod tests { fn make_test_tool() -> ScriptedTool { ScriptedTool::builder("test_api") .short_description("Test API tool") - .tool(ToolDef::new("greet", "Greet someone"), |args: &ToolArgs| { + .tool_fn(ToolDef::new("greet", "Greet someone"), |args: &ToolArgs| { let name = args.param_str("name").unwrap_or("world"); Ok(format!("hello {name}\n")) }) @@ -759,7 +759,7 @@ mod tests { let mut server = McpServer::new(bashkit::Bash::new); let tool = ScriptedTool::builder("err_api") .short_description("Error API") - .tool(ToolDef::new("fail", "Always fails"), |_args: &ToolArgs| { + .tool_fn(ToolDef::new("fail", "Always fails"), |_args: &ToolArgs| { Err("service down".to_string()) }) .build(); diff --git a/crates/bashkit-eval/src/scripting_agent.rs b/crates/bashkit-eval/src/scripting_agent.rs index 67e7d44d..6d468dc0 100644 --- a/crates/bashkit-eval/src/scripting_agent.rs +++ b/crates/bashkit-eval/src/scripting_agent.rs @@ -153,7 +153,7 @@ pub async fn run_scripted_agent( for mock_tool in &task.tools { let def = build_tool_def(mock_tool); let callback = make_mock_callback(mock_tool.mock.clone()); - builder = builder.tool(def, move |args: &ToolArgs| callback(args)); + builder = builder.tool_fn(def, move |args: &ToolArgs| callback(args)); } if task.discovery_mode { builder = builder.with_discovery(); diff --git a/crates/bashkit-js/src/lib.rs b/crates/bashkit-js/src/lib.rs index 9946b079..eb76e281 100644 --- a/crates/bashkit-js/src/lib.rs +++ b/crates/bashkit-js/src/lib.rs @@ -1544,7 +1544,7 @@ impl ScriptedTool { .map_err(|_| format!("{}: callback channel closed", tool_name))? }; - builder = builder.tool( + builder = builder.tool_fn( ToolDef::new(&entry.name, &entry.description).with_schema(entry.schema.clone()), callback, ); diff --git a/crates/bashkit-python/src/lib.rs b/crates/bashkit-python/src/lib.rs index 74cdc3e9..8f4ffc8c 100644 --- a/crates/bashkit-python/src/lib.rs +++ b/crates/bashkit-python/src/lib.rs @@ -1577,7 +1577,7 @@ impl ScriptedTool { }) }) }; - builder = builder.tool(def, callback); + builder = builder.tool_fn(def, callback); } else { // Sync callback: ctx.run(fn, params, stdin) with ContextVars. let callback = move |args: &ToolArgs| -> Result { @@ -1593,7 +1593,7 @@ impl ScriptedTool { }) }) }; - builder = builder.tool(def, callback); + builder = builder.tool_fn(def, callback); } } diff --git a/crates/bashkit/examples/scripted_tool.rs b/crates/bashkit/examples/scripted_tool.rs index 0c13d8b1..290cee71 100644 --- a/crates/bashkit/examples/scripted_tool.rs +++ b/crates/bashkit/examples/scripted_tool.rs @@ -29,10 +29,10 @@ async fn main() -> anyhow::Result<()> { // In production the callbacks would call real APIs. let tool = ScriptedTool::builder("ecommerce_api") .short_description("E-commerce API orchestrator with user, order, and inventory tools") - .tool(fakes::get_user_def(), fakes::get_user) - .tool(fakes::list_orders_def(), fakes::list_orders) - .tool(fakes::get_inventory_def(), fakes::get_inventory) - .tool(fakes::create_discount_def(), fakes::create_discount) + .tool_fn(fakes::get_user_def(), fakes::get_user) + .tool_fn(fakes::list_orders_def(), fakes::list_orders) + .tool_fn(fakes::get_inventory_def(), fakes::get_inventory) + .tool_fn(fakes::create_discount_def(), fakes::create_discount) .env("STORE_NAME", "Bashkit Shop") .build(); diff --git a/crates/bashkit/src/lib.rs b/crates/bashkit/src/lib.rs index d44622d1..bc034cf6 100644 --- a/crates/bashkit/src/lib.rs +++ b/crates/bashkit/src/lib.rs @@ -418,6 +418,9 @@ mod snapshot; mod ssh; /// Tool contract for LLM integration pub mod tool; +/// Reusable tool primitives: ToolDef, ToolArgs, ToolImpl, exec types. +#[cfg(feature = "scripted_tool")] +pub(crate) mod tool_def; /// Structured execution trace events. pub mod trace; @@ -456,6 +459,8 @@ pub use scripted_tool::{ ScriptedCommandKind, ScriptedExecutionTrace, ScriptedTool, ScriptedToolBuilder, ScriptingToolSet, ScriptingToolSetBuilder, ToolArgs, ToolCallback, ToolDef, }; +#[cfg(feature = "scripted_tool")] +pub use tool_def::{AsyncToolExec, SyncToolExec, ToolImpl}; #[cfg(feature = "http_client")] pub use network::{HttpClient, HttpHandler}; diff --git a/crates/bashkit/src/scripted_tool/execute.rs b/crates/bashkit/src/scripted_tool/execute.rs index 1f67904a..5abf68d6 100644 --- a/crates/bashkit/src/scripted_tool/execute.rs +++ b/crates/bashkit/src/scripted_tool/execute.rs @@ -1,4 +1,4 @@ -//! ScriptedTool execution: Tool impl, builtin adapter, flag parser, documentation helpers. +//! ScriptedTool execution: Tool impl, builtin adapter, documentation helpers. use super::{ CallbackKind, ScriptedCommandInvocation, ScriptedCommandKind, ScriptedExecutionTrace, @@ -12,6 +12,7 @@ use crate::tool::{ Tool, ToolError, ToolExecution, ToolOutputChunk, ToolRequest, ToolResponse, ToolStatus, VERSION, localized, tool_output_from_response, tool_request_from_value, }; +use crate::tool_def::{parse_flags, usage_from_schema}; use async_trait::async_trait; use schemars::schema_for; use std::sync::{Arc, Mutex}; @@ -34,110 +35,6 @@ fn push_invocation( }); } -// ============================================================================ -// Flag parser — `--key value` / `--key=value` → JSON object -// ============================================================================ - -/// Parse `--key value` and `--key=value` flags into a JSON object. -/// Types are coerced according to the schema's property definitions. -/// Unknown flags (not in schema) are kept as strings. -/// Bare `--flag` without a value is treated as `true` if the schema says boolean, -/// otherwise as `true` when the next arg also starts with `--` or is absent. -fn parse_flags( - raw_args: &[String], - schema: &serde_json::Value, -) -> std::result::Result { - let properties = schema - .get("properties") - .and_then(|p| p.as_object()) - .cloned() - .unwrap_or_default(); - - let mut result = serde_json::Map::new(); - let mut i = 0; - - while i < raw_args.len() { - let arg = &raw_args[i]; - - let Some(flag) = arg.strip_prefix("--") else { - return Err(format!("expected --flag, got: {arg}")); - }; - - // --key=value - if let Some((key, raw_value)) = flag.split_once('=') { - let value = coerce_value(raw_value, properties.get(key)); - result.insert(key.to_string(), value); - i += 1; - continue; - } - - // --flag (boolean) or --key value - let key = flag; - let prop_schema = properties.get(key); - let is_boolean = prop_schema - .and_then(|s| s.get("type")) - .and_then(|t| t.as_str()) - == Some("boolean"); - - if is_boolean { - result.insert(key.to_string(), serde_json::Value::Bool(true)); - i += 1; - } else if i + 1 < raw_args.len() && !raw_args[i + 1].starts_with("--") { - let raw_value = &raw_args[i + 1]; - let value = coerce_value(raw_value, prop_schema); - result.insert(key.to_string(), value); - i += 2; - } else { - // No value follows and not boolean — treat as true - result.insert(key.to_string(), serde_json::Value::Bool(true)); - i += 1; - } - } - - Ok(serde_json::Value::Object(result)) -} - -/// Coerce a raw string value to the type declared in the property schema. -fn coerce_value(raw: &str, prop_schema: Option<&serde_json::Value>) -> serde_json::Value { - let type_str = prop_schema - .and_then(|s| s.get("type")) - .and_then(|t| t.as_str()) - .unwrap_or("string"); - - match type_str { - "integer" => raw - .parse::() - .map(serde_json::Value::from) - .unwrap_or_else(|_| serde_json::Value::String(raw.to_string())), - "number" => raw - .parse::() - .map(|n| serde_json::json!(n)) - .unwrap_or_else(|_| serde_json::Value::String(raw.to_string())), - "boolean" => match raw { - "true" | "1" | "yes" => serde_json::Value::Bool(true), - "false" | "0" | "no" => serde_json::Value::Bool(false), - _ => serde_json::Value::String(raw.to_string()), - }, - _ => serde_json::Value::String(raw.to_string()), - } -} - -/// Generate a usage hint from schema properties: `--id --name `. -fn usage_from_schema(schema: &serde_json::Value) -> Option { - let props = schema.get("properties")?.as_object()?; - if props.is_empty() { - return None; - } - let flags: Vec = props - .iter() - .map(|(key, prop)| { - let ty = prop.get("type").and_then(|t| t.as_str()).unwrap_or("value"); - format!("--{key} <{ty}>") - }) - .collect(); - Some(flags.join(" ")) -} - // ============================================================================ // ToolBuiltinAdapter — wraps ToolCallback as a Builtin // ============================================================================ @@ -795,7 +692,7 @@ mod tests { fn build_help_test_tool() -> ScriptedTool { ScriptedTool::builder("test_api") .short_description("Test API") - .tool( + .tool_fn( ToolDef::new("get_user", "Fetch user by ID").with_schema(serde_json::json!({ "type": "object", "properties": { @@ -804,7 +701,7 @@ mod tests { })), |_args: &super::ToolArgs| Ok("{\"id\":1}\n".to_string()), ) - .tool( + .tool_fn( ToolDef::new("list_orders", "List orders for user").with_schema( serde_json::json!({ "type": "object", @@ -909,7 +806,7 @@ mod tests { async fn test_compact_prompt_omits_usage() { let tool = ScriptedTool::builder("compact_test") .compact_prompt(true) - .tool( + .tool_fn( ToolDef::new("get_user", "Fetch user").with_schema(serde_json::json!({ "type": "object", "properties": { "id": {"type": "integer"} } @@ -925,7 +822,7 @@ mod tests { #[tokio::test] async fn test_non_compact_prompt_has_usage() { let tool = ScriptedTool::builder("full_test") - .tool( + .tool_fn( ToolDef::new("get_user", "Fetch user").with_schema(serde_json::json!({ "type": "object", "properties": { "id": {"type": "integer"} } @@ -945,7 +842,7 @@ mod tests { let tool = ScriptedTool::builder("test") .short_description("test") - .tool( + .tool_fn( ToolDef::new("fail", "Always fails"), |_args: &super::ToolArgs| Err("service error".to_string()), ) @@ -969,31 +866,31 @@ mod tests { fn build_discover_test_tool() -> ScriptedTool { ScriptedTool::builder("big_api") .short_description("Big API") - .tool( + .tool_fn( ToolDef::new("create_charge", "Create a payment charge") .with_category("payments") .with_tags(&["billing", "write"]), |_args: &super::ToolArgs| Ok("ok\n".to_string()), ) - .tool( + .tool_fn( ToolDef::new("refund", "Issue a refund") .with_category("payments") .with_tags(&["billing", "write"]), |_args: &super::ToolArgs| Ok("ok\n".to_string()), ) - .tool( + .tool_fn( ToolDef::new("get_user", "Fetch user by ID") .with_category("users") .with_tags(&["read"]), |_args: &super::ToolArgs| Ok("ok\n".to_string()), ) - .tool( + .tool_fn( ToolDef::new("delete_user", "Delete a user account") .with_category("users") .with_tags(&["admin", "write"]), |_args: &super::ToolArgs| Ok("ok\n".to_string()), ) - .tool( + .tool_fn( ToolDef::new("get_inventory", "Check inventory levels").with_category("inventory"), |_args: &super::ToolArgs| Ok("ok\n".to_string()), ) @@ -1154,7 +1051,7 @@ mod tests { #[tokio::test] async fn test_callback_error_sanitized_by_default() { let tool = ScriptedTool::builder("api") - .tool( + .tool_fn( ToolDef::new("fail", "Always fails"), |_args: &super::ToolArgs| { Err("connection failed: postgres://admin:secret@internal-db:5432/prod".into()) @@ -1181,7 +1078,7 @@ mod tests { async fn test_callback_error_unsanitized_when_disabled() { let tool = ScriptedTool::builder("api") .sanitize_errors(false) - .tool( + .tool_fn( ToolDef::new("fail", "Always fails"), |_args: &super::ToolArgs| { Err("connection failed: postgres://admin:secret@internal-db:5432/prod".into()) diff --git a/crates/bashkit/src/scripted_tool/mod.rs b/crates/bashkit/src/scripted_tool/mod.rs index 2f80c72c..66eafd41 100644 --- a/crates/bashkit/src/scripted_tool/mod.rs +++ b/crates/bashkit/src/scripted_tool/mod.rs @@ -35,7 +35,7 @@ //! //! # tokio_test::block_on(async { //! let tool = ScriptedTool::builder("api") -//! .tool( +//! .tool_fn( //! ToolDef::new("greet", "Greet a user") //! .with_schema(serde_json::json!({ //! "type": "object", @@ -75,7 +75,7 @@ //! let k = api_key.clone(); //! let u = base_url.clone(); //! let mut builder = ScriptedTool::builder("api"); -//! builder = builder.tool( +//! builder = builder.tool_fn( //! ToolDef::new("get_user", "Fetch user by ID"), //! move |args: &ToolArgs| { //! let _key = &*k; // shared API key @@ -86,7 +86,7 @@ //! //! let k2 = api_key.clone(); //! let u2 = base_url.clone(); -//! builder = builder.tool( +//! builder = builder.tool_fn( //! ToolDef::new("list_orders", "List orders"), //! move |_args: &ToolArgs| { //! let _key = &*k2; @@ -106,7 +106,7 @@ //! let call_count = Arc::new(Mutex::new(0u64)); //! let c = call_count.clone(); //! let tool = ScriptedTool::builder("api") -//! .tool( +//! .tool_fn( //! ToolDef::new("tracked", "Counted call"), //! move |_args: &ToolArgs| { //! let mut count = c.lock().unwrap(); @@ -132,131 +132,23 @@ mod toolset; pub use toolset::{DiscoverTool, DiscoveryMode, ScriptingToolSet, ScriptingToolSetBuilder}; +// Re-export foundational types from tool_def (they used to live here). +pub use crate::tool_def::{ + AsyncToolCallback, AsyncToolExec, SyncToolExec, ToolArgs, ToolCallback, ToolDef, ToolImpl, +}; + use crate::{ExecutionLimits, Tool, ToolService}; use schemars::schema_for; use serde::{Deserialize, Serialize}; -use std::future::Future; -use std::pin::Pin; use std::sync::{Arc, Mutex}; -// ============================================================================ -// ToolDef — OpenAPI-style tool definition -// ============================================================================ - -/// OpenAPI-style tool definition: name, description, input schema. -/// -/// Describes a sub-tool registered with [`ScriptedToolBuilder`]. -/// The `input_schema` is optional JSON Schema for documentation / LLM prompts -/// and for type coercion of `--key value` flags. -#[derive(Clone)] -pub struct ToolDef { - /// Command name used as bash builtin (e.g. `"get_user"`). - pub name: String, - /// Human-readable description for LLM consumption. - pub description: String, - /// JSON Schema describing accepted arguments. Empty object if unspecified. - pub input_schema: serde_json::Value, - /// Categorical tags for discovery (e.g. `["admin", "billing"]`). - pub tags: Vec, - /// Grouping category for discovery (e.g. `"payments"`). - pub category: Option, -} - -impl ToolDef { - /// Create a tool definition with name and description. - pub fn new(name: impl Into, description: impl Into) -> Self { - Self { - name: name.into(), - description: description.into(), - input_schema: serde_json::Value::Object(Default::default()), - tags: Vec::new(), - category: None, - } - } - - /// Attach a JSON Schema for the tool's input parameters. - pub fn with_schema(mut self, schema: serde_json::Value) -> Self { - self.input_schema = schema; - self - } - - /// Add categorical tags for discovery filtering. - pub fn with_tags(mut self, tags: &[&str]) -> Self { - self.tags = tags.iter().map(|s| s.to_string()).collect(); - self - } - - /// Set the grouping category for discovery. - pub fn with_category(mut self, category: &str) -> Self { - self.category = Some(category.to_string()); - self - } -} - -// ============================================================================ -// ToolArgs — parsed arguments passed to callbacks -// ============================================================================ - -/// Parsed arguments passed to a tool callback. -/// -/// `params` is a JSON object built from `--key value` flags, with values -/// type-coerced per the `ToolDef`'s `input_schema`. -/// `stdin` carries pipeline input from a prior command, if any. -pub struct ToolArgs { - /// Parsed parameters as a JSON object. Keys from `--key value` flags. - pub params: serde_json::Value, - /// Pipeline input from a prior command (e.g. `echo data | tool`). - pub stdin: Option, -} - -impl ToolArgs { - /// Get a string parameter by name. - pub fn param_str(&self, key: &str) -> Option<&str> { - self.params.get(key).and_then(|v| v.as_str()) - } - - /// Get an integer parameter by name. - pub fn param_i64(&self, key: &str) -> Option { - self.params.get(key).and_then(|v| v.as_i64()) - } - - /// Get a float parameter by name. - pub fn param_f64(&self, key: &str) -> Option { - self.params.get(key).and_then(|v| v.as_f64()) - } - - /// Get a boolean parameter by name. - pub fn param_bool(&self, key: &str) -> Option { - self.params.get(key).and_then(|v| v.as_bool()) - } -} - -// ============================================================================ -// ToolCallback — execution callback type -// ============================================================================ - -/// Execution callback for a registered tool (synchronous). -/// -/// Receives parsed [`ToolArgs`] with typed parameters and optional stdin. -/// Return `Ok(stdout)` on success or `Err(message)` on failure. -pub type ToolCallback = Arc Result + Send + Sync>; - -/// Async execution callback for a registered tool. -/// -/// Same contract as [`ToolCallback`] but returns a `Future`, allowing -/// non-blocking I/O inside the callback. Takes owned [`ToolArgs`] because -/// the future may outlive the borrow. -pub type AsyncToolCallback = Arc< - dyn Fn(ToolArgs) -> Pin> + Send>> + Send + Sync, ->; - /// Sync or async callback for a registered tool. #[derive(Clone)] pub enum CallbackKind { /// Synchronous callback — blocks until complete. - Sync(ToolCallback), + Sync(SyncToolExec), /// Asynchronous callback — `.await`ed inside the interpreter. - Async(AsyncToolCallback), + Async(AsyncToolExec), } // ============================================================================ @@ -297,6 +189,26 @@ pub(crate) struct RegisteredTool { pub(crate) callback: CallbackKind, } +impl RegisteredTool { + /// Create from a [`ToolImpl`], converting its exec/exec_sync to a + /// [`CallbackKind`]. Prefers async when available. + pub(crate) fn from_tool_impl(tool: ToolImpl) -> Self { + let callback = if let Some(async_cb) = tool.exec { + CallbackKind::Async(async_cb) + } else if let Some(sync_cb) = tool.exec_sync { + CallbackKind::Sync(sync_cb) + } else { + // Schema-only ToolImpl — wrap as a sync callback that always errors. + let name = tool.def.name.clone(); + CallbackKind::Sync(Arc::new(move |_| Err(format!("{name}: no exec defined")))) + }; + Self { + def: tool.def, + callback, + } + } +} + // ============================================================================ // ScriptedToolBuilder // ============================================================================ @@ -308,7 +220,7 @@ pub(crate) struct RegisteredTool { /// /// let tool = ScriptedTool::builder("net") /// .short_description("Network tools") -/// .tool( +/// .tool_fn( /// ToolDef::new("ping", "Ping a host") /// .with_schema(serde_json::json!({ /// "type": "object", @@ -359,33 +271,44 @@ impl ScriptedToolBuilder { self } - /// Register a tool with its definition and synchronous execution callback. + /// Register a [`ToolImpl`] (definition + exec functions). /// - /// The callback receives [`ToolArgs`] with `--key value` flags parsed into + /// This is the preferred registration method. The `ToolImpl` carries its own + /// name, schema, and sync/async exec. + pub fn tool(mut self, tool: ToolImpl) -> Self { + self.tools.push(RegisteredTool::from_tool_impl(tool)); + self + } + + /// Register a tool with its definition and synchronous exec function. + /// + /// Convenience shorthand — constructs a [`ToolImpl`] internally. + /// The exec receives [`ToolArgs`] with `--key value` flags parsed into /// a JSON object, type-coerced per the schema. - pub fn tool( + pub fn tool_fn( mut self, def: ToolDef, - callback: impl Fn(&ToolArgs) -> Result + Send + Sync + 'static, + exec: impl Fn(&ToolArgs) -> Result + Send + Sync + 'static, ) -> Self { self.tools.push(RegisteredTool { def, - callback: CallbackKind::Sync(Arc::new(callback)), + callback: CallbackKind::Sync(Arc::new(exec)), }); self } - /// Register a tool with its definition and **async** execution callback. + /// Register a tool with its definition and **async** exec function. /// - /// Same as [`tool()`](Self::tool) but the callback returns a `Future`, + /// Convenience shorthand — constructs a [`ToolImpl`] internally. + /// Same as [`tool_fn()`](Self::tool_fn) but returns a `Future`, /// allowing non-blocking I/O. Takes owned [`ToolArgs`] because the future /// may outlive the borrow. - pub fn async_tool(mut self, def: ToolDef, callback: F) -> Self + pub fn async_tool_fn(mut self, def: ToolDef, exec: F) -> Self where F: Fn(ToolArgs) -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, { - let cb: AsyncToolCallback = Arc::new(move |args| Box::pin(callback(args))); + let cb: AsyncToolExec = Arc::new(move |args| Box::pin(exec(args))); self.tools.push(RegisteredTool { def, callback: CallbackKind::Async(cb), @@ -566,7 +489,7 @@ mod tests { fn build_test_tool() -> ScriptedTool { ScriptedTool::builder("test_api") .short_description("Test API") - .tool( + .tool_fn( ToolDef::new("get_user", "Fetch user by id").with_schema(serde_json::json!({ "type": "object", "properties": { @@ -580,7 +503,7 @@ mod tests { )) }, ) - .tool( + .tool_fn( ToolDef::new("get_orders", "List orders for user").with_schema(serde_json::json!({ "type": "object", "properties": { @@ -595,11 +518,11 @@ mod tests { )) }, ) - .tool( + .tool_fn( ToolDef::new("fail_tool", "Always fails"), |_args: &ToolArgs| Err("service unavailable".to_string()), ) - .tool( + .tool_fn( ToolDef::new("from_stdin", "Read from stdin, uppercase it"), |args: &ToolArgs| match args.stdin.as_deref() { Some(input) => Ok(input.to_uppercase()), @@ -621,7 +544,7 @@ mod tests { #[test] fn test_builder_default_short_description() { let tool = ScriptedTool::builder("mytools") - .tool(ToolDef::new("noop", "No-op"), |_args: &ToolArgs| { + .tool_fn(ToolDef::new("noop", "No-op"), |_args: &ToolArgs| { Ok("ok\n".to_string()) }) .build(); @@ -660,7 +583,7 @@ mod tests { #[test] fn test_system_prompt_includes_schema() { let tool = ScriptedTool::builder("schema_test") - .tool( + .tool_fn( ToolDef::new("get_user", "Fetch user by id").with_schema(serde_json::json!({ "type": "object", "properties": { @@ -690,7 +613,7 @@ mod tests { #[test] fn test_builder_contract_helpers() { let builder = ScriptedTool::builder("test_api") - .tool(ToolDef::new("ping", "Ping"), |_args: &ToolArgs| { + .tool_fn(ToolDef::new("ping", "Ping"), |_args: &ToolArgs| { Ok("pong\n".to_string()) }); let definition = builder.build_tool_definition(); @@ -708,7 +631,7 @@ mod tests { use tower::ServiceExt; let service = ScriptedTool::builder("test_api") - .tool(ToolDef::new("ping", "Ping"), |_args: &ToolArgs| { + .tool_fn(ToolDef::new("ping", "Ping"), |_args: &ToolArgs| { Ok("pong\n".to_string()) }) .build_service(); @@ -726,7 +649,7 @@ mod tests { fn test_locale_localizes_description() { let tool = ScriptedTool::builder("ua_api") .locale("uk-UA") - .tool(ToolDef::new("ping", "Ping"), |_args: &ToolArgs| { + .tool_fn(ToolDef::new("ping", "Ping"), |_args: &ToolArgs| { Ok("pong\n".to_string()) }) .build(); @@ -899,7 +822,7 @@ mod tests { async fn test_execute_with_env() { let tool = ScriptedTool::builder("env_test") .env("API_BASE", "https://api.example.com") - .tool(ToolDef::new("noop", "No-op"), |_args: &ToolArgs| { + .tool_fn(ToolDef::new("noop", "No-op"), |_args: &ToolArgs| { Ok("ok\n".to_string()) }) .build(); @@ -968,7 +891,7 @@ mod tests { #[tokio::test] async fn test_boolean_flag() { let tool = ScriptedTool::builder("bool_test") - .tool( + .tool_fn( ToolDef::new("search", "Search").with_schema(serde_json::json!({ "type": "object", "properties": { @@ -997,7 +920,7 @@ mod tests { #[tokio::test] async fn test_no_schema_treats_as_strings() { let tool = ScriptedTool::builder("str_test") - .tool( + .tool_fn( ToolDef::new("echo_args", "Echo params as JSON"), |args: &ToolArgs| Ok(format!("{}\n", args.params)), ) @@ -1031,14 +954,14 @@ mod tests { let log2 = call_log.clone(); let tool = ScriptedTool::builder("ctx_test") - .tool( + .tool_fn( ToolDef::new("tool_a", "First tool"), move |_args: &ToolArgs| { log1.lock().expect("lock").push(format!("a:{}", *s1)); Ok("a\n".to_string()) }, ) - .tool( + .tool_fn( ToolDef::new("tool_b", "Second tool"), move |_args: &ToolArgs| { log2.lock().expect("lock").push(format!("b:{}", *s2)); @@ -1066,7 +989,7 @@ mod tests { let c = counter.clone(); let tool = ScriptedTool::builder("mut_test") - .tool( + .tool_fn( ToolDef::new("increment", "Bump counter"), move |_args: &ToolArgs| { let mut count = c.lock().expect("lock"); @@ -1091,7 +1014,7 @@ mod tests { #[tokio::test] async fn test_fresh_interpreter_per_execute() { let tool = ScriptedTool::builder("isolation_test") - .tool(ToolDef::new("noop", "No-op"), |_args: &ToolArgs| { + .tool_fn(ToolDef::new("noop", "No-op"), |_args: &ToolArgs| { Ok("ok\n".to_string()) }) .build(); @@ -1123,7 +1046,7 @@ mod tests { let c = counter.clone(); let tool = ScriptedTool::builder("persist_test") - .tool( + .tool_fn( ToolDef::new("count", "Count calls"), move |_args: &ToolArgs| { let mut n = c.lock().expect("lock"); @@ -1181,7 +1104,7 @@ mod tests { #[tokio::test] async fn test_async_tool_basic() { let tool = ScriptedTool::builder("async_api") - .async_tool( + .async_tool_fn( ToolDef::new("greet", "Greet async").with_schema(serde_json::json!({ "type": "object", "properties": { "name": {"type": "string"} } @@ -1206,10 +1129,10 @@ mod tests { #[tokio::test] async fn test_mixed_sync_async_tools() { let tool = ScriptedTool::builder("mixed") - .tool(ToolDef::new("sync_ping", "Sync"), |_args: &ToolArgs| { + .tool_fn(ToolDef::new("sync_ping", "Sync"), |_args: &ToolArgs| { Ok("sync-pong\n".to_string()) }) - .async_tool( + .async_tool_fn( ToolDef::new("async_ping", "Async"), |_args: ToolArgs| async move { Ok("async-pong\n".to_string()) }, ) @@ -1230,7 +1153,7 @@ mod tests { async fn test_async_tool_error_propagates() { let tool = ScriptedTool::builder("err_api") .sanitize_errors(false) - .async_tool( + .async_tool_fn( ToolDef::new("fail", "Always fails"), |_args: ToolArgs| async move { Err("async boom".to_string()) }, ) @@ -1249,7 +1172,7 @@ mod tests { #[tokio::test] async fn test_async_tool_stdin_pipe() { let tool = ScriptedTool::builder("pipe_api") - .async_tool( + .async_tool_fn( ToolDef::new("upper", "Uppercase stdin"), |args: ToolArgs| async move { Ok(args.stdin.unwrap_or_default().to_uppercase()) }, ) @@ -1264,4 +1187,86 @@ mod tests { assert_eq!(resp.exit_code, 0); assert!(resp.stdout.contains("HELLO")); } + + // -- ToolImpl registration -- + + #[tokio::test] + async fn test_tool_impl_in_scripted_tool() { + let get_user = ToolImpl::new(ToolDef::new("get_user", "Fetch user by ID").with_schema( + serde_json::json!({ + "type": "object", + "properties": { "id": {"type": "integer"} }, + "required": ["id"] + }), + )) + .with_exec_sync(|args| { + let id = args.param_i64("id").ok_or("missing --id")?; + Ok(format!("{{\"id\":{id},\"name\":\"Alice\"}}\n")) + }); + + let tool = ScriptedTool::builder("api") + .short_description("Test API") + .tool(get_user) + .build(); + + assert!(tool.system_prompt().contains("get_user")); + assert!(tool.help().contains("get_user")); + + let resp = tool + .execute(ToolRequest { + commands: "get_user --id 42 | jq -r '.name'".to_string(), + timeout_ms: None, + }) + .await; + assert_eq!(resp.exit_code, 0); + assert_eq!(resp.stdout.trim(), "Alice"); + } + + #[tokio::test] + async fn test_tool_impl_async_exec_in_scripted_tool() { + let greet = ToolImpl::new(ToolDef::new("greet", "Greet someone").with_schema( + serde_json::json!({ + "type": "object", + "properties": { "name": {"type": "string"} } + }), + )) + .with_exec(|args| async move { + let name = args.param_str("name").unwrap_or("world"); + Ok(format!("hello {name}\n")) + }); + + let tool = ScriptedTool::builder("api").tool(greet).build(); + + let resp = tool + .execute(ToolRequest { + commands: "greet --name Bob".to_string(), + timeout_ms: None, + }) + .await; + assert_eq!(resp.exit_code, 0); + assert_eq!(resp.stdout.trim(), "hello Bob"); + } + + #[tokio::test] + async fn test_tool_impl_mixed_with_tool_fn() { + let tool_impl = ToolImpl::new(ToolDef::new("impl_cmd", "From ToolImpl")) + .with_exec_sync(|_args| Ok("from_impl\n".to_string())); + + let tool = ScriptedTool::builder("mixed") + .tool(tool_impl) + .tool_fn(ToolDef::new("fn_cmd", "From tool_fn"), |_args| { + Ok("from_fn\n".to_string()) + }) + .build(); + + let resp = tool + .execute(ToolRequest { + commands: "echo $(impl_cmd) $(fn_cmd)".to_string(), + timeout_ms: None, + }) + .await; + assert_eq!(resp.exit_code, 0); + assert!(resp.stdout.contains("from_impl")); + assert!(resp.stdout.contains("from_fn")); + } } diff --git a/crates/bashkit/src/scripted_tool/toolset.rs b/crates/bashkit/src/scripted_tool/toolset.rs index a8832aa9..9699ba6f 100644 --- a/crates/bashkit/src/scripted_tool/toolset.rs +++ b/crates/bashkit/src/scripted_tool/toolset.rs @@ -6,7 +6,7 @@ // DiscoverTool (discover/help only). use super::{ - CallbackKind, RegisteredTool, ScriptedExecutionTrace, ScriptedTool, ToolArgs, ToolDef, + CallbackKind, RegisteredTool, ScriptedExecutionTrace, ScriptedTool, ToolArgs, ToolDef, ToolImpl, }; use crate::ExecutionLimits; use crate::tool::{Tool, ToolError, ToolRequest, ToolResponse, ToolStatus, VERSION}; @@ -45,7 +45,7 @@ pub enum DiscoveryMode { /// /// # tokio_test::block_on(async { /// let toolset = ScriptingToolSet::builder("api") -/// .tool( +/// .tool_fn( /// ToolDef::new("greet", "Greet someone").with_category("social"), /// |_args: &ToolArgs| Ok("hello\n".to_string()), /// ) @@ -188,7 +188,7 @@ impl Tool for DiscoverTool { /// /// let toolset = ScriptingToolSet::builder("api") /// .short_description("Example API") -/// .tool( +/// .tool_fn( /// ToolDef::new("ping", "Ping a host"), /// |_args: &ToolArgs| Ok("pong\n".to_string()), /// ) @@ -229,26 +229,36 @@ impl ScriptingToolSetBuilder { self } - /// Register a tool with its definition and synchronous execution callback. - pub fn tool( + /// Register a [`ToolImpl`] (definition + exec functions). + pub fn tool(mut self, tool: ToolImpl) -> Self { + self.tools.push(RegisteredTool::from_tool_impl(tool)); + self + } + + /// Register a tool with its definition and synchronous exec function. + /// + /// Convenience shorthand — constructs a [`ToolImpl`] internally. + pub fn tool_fn( mut self, def: ToolDef, - callback: impl Fn(&ToolArgs) -> Result + Send + Sync + 'static, + exec: impl Fn(&ToolArgs) -> Result + Send + Sync + 'static, ) -> Self { self.tools.push(RegisteredTool { def, - callback: CallbackKind::Sync(Arc::new(callback)), + callback: CallbackKind::Sync(Arc::new(exec)), }); self } - /// Register a tool with its definition and **async** execution callback. - pub fn async_tool(mut self, def: ToolDef, callback: F) -> Self + /// Register a tool with its definition and **async** exec function. + /// + /// Convenience shorthand — constructs a [`ToolImpl`] internally. + pub fn async_tool_fn(mut self, def: ToolDef, exec: F) -> Self where F: Fn(ToolArgs) -> Fut + Send + Sync + 'static, Fut: std::future::Future> + Send + 'static, { - let cb: super::AsyncToolCallback = Arc::new(move |args| Box::pin(callback(args))); + let cb: super::AsyncToolExec = Arc::new(move |args| Box::pin(exec(args))); self.tools.push(RegisteredTool { def, callback: CallbackKind::Async(cb), @@ -301,11 +311,11 @@ impl ScriptingToolSetBuilder { match ®.callback { CallbackKind::Sync(cb) => { let cb = Arc::clone(cb); - builder = builder.tool(reg.def.clone(), move |args: &ToolArgs| (cb)(args)); + builder = builder.tool_fn(reg.def.clone(), move |args: &ToolArgs| (cb)(args)); } CallbackKind::Async(cb) => { let cb = Arc::clone(cb); - builder = builder.async_tool(reg.def.clone(), move |args: ToolArgs| { + builder = builder.async_tool_fn(reg.def.clone(), move |args: ToolArgs| { let cb = Arc::clone(&cb); async move { (cb)(args).await } }); @@ -343,7 +353,7 @@ impl ScriptingToolSetBuilder { /// # tokio_test::block_on(async { /// // Exclusive mode (default): one tool with full schemas /// let toolset = ScriptingToolSet::builder("api") -/// .tool( +/// .tool_fn( /// ToolDef::new("greet", "Greet someone") /// .with_schema(serde_json::json!({ /// "type": "object", @@ -370,7 +380,7 @@ impl ScriptingToolSetBuilder { /// /// // Discovery mode: two tools /// let toolset = ScriptingToolSet::builder("api") -/// .tool( +/// .tool_fn( /// ToolDef::new("greet", "Greet someone"), /// |_args: &ToolArgs| Ok("hello\n".to_string()), /// ) @@ -441,7 +451,7 @@ mod tests { fn make_tools() -> ScriptingToolSetBuilder { ScriptingToolSet::builder("test_api") .short_description("Test API") - .tool( + .tool_fn( ToolDef::new("get_user", "Fetch user by ID") .with_schema(serde_json::json!({ "type": "object", @@ -456,7 +466,7 @@ mod tests { Ok(format!("{{\"id\":{id},\"name\":\"Alice\"}}\n")) }, ) - .tool( + .tool_fn( ToolDef::new("list_orders", "List orders for a user") .with_schema(serde_json::json!({ "type": "object", @@ -559,7 +569,7 @@ mod tests { #[test] fn test_default_short_description() { let toolset = ScriptingToolSet::builder("mytools") - .tool(ToolDef::new("noop", "No-op"), |_: &ToolArgs| { + .tool_fn(ToolDef::new("noop", "No-op"), |_: &ToolArgs| { Ok("ok\n".into()) }) .build(); @@ -695,7 +705,7 @@ mod tests { async fn test_env_vars_passed_through() { let toolset = ScriptingToolSet::builder("env_test") .env("MY_VAR", "hello") - .tool(ToolDef::new("noop", "No-op"), |_: &ToolArgs| { + .tool_fn(ToolDef::new("noop", "No-op"), |_: &ToolArgs| { Ok("ok\n".into()) }) .build(); @@ -822,4 +832,68 @@ mod tests { Ok(_) => panic!("expected error for disallowed command"), } } + + // -- ToolImpl registration -- + + #[tokio::test] + async fn test_tool_impl_registration() { + let get_user = ToolImpl::new( + ToolDef::new("get_user", "Fetch user by ID") + .with_schema(serde_json::json!({ + "type": "object", + "properties": { "id": {"type": "integer"} }, + "required": ["id"] + })) + .with_category("users"), + ) + .with_exec_sync(|args| { + let id = args.param_i64("id").ok_or("missing --id")?; + Ok(format!("{{\"id\":{id},\"name\":\"Alice\"}}\n")) + }); + + let list_orders = ToolImpl::new( + ToolDef::new("list_orders", "List orders") + .with_schema(serde_json::json!({ + "type": "object", + "properties": { "user_id": {"type": "integer"} } + })) + .with_category("orders"), + ) + .with_exec_sync(|args| { + let uid = args.param_i64("user_id").ok_or("missing --user_id")?; + Ok(format!("[{{\"order_id\":1,\"user_id\":{uid}}}]\n")) + }); + + // Exclusive mode + let toolset = ScriptingToolSet::builder("api") + .short_description("Test API") + .tool(get_user.clone()) + .tool(list_orders.clone()) + .build(); + + let tools = toolset.tools(); + assert_eq!(tools.len(), 1); + assert!(tools[0].system_prompt().contains("get_user")); + assert!(tools[0].system_prompt().contains("list_orders")); + + let resp = tools[0] + .execute(ToolRequest { + commands: "get_user --id 1 | jq -r '.name'".into(), + timeout_ms: None, + }) + .await; + assert_eq!(resp.stdout.trim(), "Alice"); + + // Discovery mode + let toolset = ScriptingToolSet::builder("api") + .tool(get_user) + .tool(list_orders) + .with_discovery() + .build(); + + let tools = toolset.tools(); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].name(), "api"); + assert_eq!(tools[1].name(), "api_discover"); + } } diff --git a/crates/bashkit/src/tool_def.rs b/crates/bashkit/src/tool_def.rs new file mode 100644 index 00000000..01d47bcb --- /dev/null +++ b/crates/bashkit/src/tool_def.rs @@ -0,0 +1,440 @@ +// ToolDef, ToolArgs, ToolImpl — reusable tool primitives. +// +// These types live here (not in scripted_tool/) so that both Bash and +// ScriptedTool can import them without circular dependencies. +// +// Dependency direction: builtins → tool_def → {lib.rs, scripted_tool, tool.rs} + +use crate::builtins::{Builtin, Context}; +use crate::error::Result; +use crate::interpreter::ExecResult; +use async_trait::async_trait; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +// ============================================================================ +// ToolDef — OpenAPI-style tool definition (metadata only) +// ============================================================================ + +/// OpenAPI-style tool definition: name, description, input schema. +/// +/// Describes a sub-tool registered with a `ScriptedToolBuilder` or usable +/// standalone. The `input_schema` is optional JSON Schema for documentation / +/// LLM prompts and for type coercion of `--key value` flags. +#[derive(Clone)] +pub struct ToolDef { + /// Command name used as bash builtin (e.g. `"get_user"`). + pub name: String, + /// Human-readable description for LLM consumption. + pub description: String, + /// JSON Schema describing accepted arguments. Empty object if unspecified. + pub input_schema: serde_json::Value, + /// Categorical tags for discovery (e.g. `["admin", "billing"]`). + pub tags: Vec, + /// Grouping category for discovery (e.g. `"payments"`). + pub category: Option, +} + +impl ToolDef { + /// Create a tool definition with name and description. + pub fn new(name: impl Into, description: impl Into) -> Self { + Self { + name: name.into(), + description: description.into(), + input_schema: serde_json::Value::Object(Default::default()), + tags: Vec::new(), + category: None, + } + } + + /// Attach a JSON Schema for the tool's input parameters. + pub fn with_schema(mut self, schema: serde_json::Value) -> Self { + self.input_schema = schema; + self + } + + /// Add categorical tags for discovery filtering. + pub fn with_tags(mut self, tags: &[&str]) -> Self { + self.tags = tags.iter().map(|s| s.to_string()).collect(); + self + } + + /// Set the grouping category for discovery. + pub fn with_category(mut self, category: &str) -> Self { + self.category = Some(category.to_string()); + self + } +} + +// ============================================================================ +// ToolArgs — parsed arguments passed to exec functions +// ============================================================================ + +/// Parsed arguments passed to a tool exec function. +/// +/// `params` is a JSON object built from `--key value` flags, with values +/// type-coerced per the `ToolDef`'s `input_schema`. +/// `stdin` carries pipeline input from a prior command, if any. +pub struct ToolArgs { + /// Parsed parameters as a JSON object. Keys from `--key value` flags. + pub params: serde_json::Value, + /// Pipeline input from a prior command (e.g. `echo data | tool`). + pub stdin: Option, +} + +impl ToolArgs { + /// Get a string parameter by name. + pub fn param_str(&self, key: &str) -> Option<&str> { + self.params.get(key).and_then(|v| v.as_str()) + } + + /// Get an integer parameter by name. + pub fn param_i64(&self, key: &str) -> Option { + self.params.get(key).and_then(|v| v.as_i64()) + } + + /// Get a float parameter by name. + pub fn param_f64(&self, key: &str) -> Option { + self.params.get(key).and_then(|v| v.as_f64()) + } + + /// Get a boolean parameter by name. + pub fn param_bool(&self, key: &str) -> Option { + self.params.get(key).and_then(|v| v.as_bool()) + } +} + +// ============================================================================ +// Exec types — sync and async execution functions +// ============================================================================ + +/// Synchronous execution function for a tool. +/// +/// Receives parsed [`ToolArgs`] with typed parameters and optional stdin. +/// Return `Ok(stdout)` on success or `Err(message)` on failure. +pub type SyncToolExec = Arc std::result::Result + Send + Sync>; + +/// Asynchronous execution function for a tool. +/// +/// Same contract as [`SyncToolExec`] but returns a `Future`, allowing +/// non-blocking I/O. Takes owned [`ToolArgs`] because the future may +/// outlive the borrow. +pub type AsyncToolExec = Arc< + dyn Fn(ToolArgs) -> Pin> + Send>> + + Send + + Sync, +>; + +// Keep old names as aliases for backward compatibility. +/// Alias for [`SyncToolExec`] (backward compatibility). +pub type ToolCallback = SyncToolExec; +/// Alias for [`AsyncToolExec`] (backward compatibility). +pub type AsyncToolCallback = AsyncToolExec; + +// ============================================================================ +// ToolImpl — complete tool: metadata + execution +// ============================================================================ + +/// Complete tool: definition + sync/async exec functions. +/// +/// Implements [`Builtin`] so it can be registered directly in a Bash +/// interpreter or used inside a `ScriptedTool`. +/// +/// # Example +/// +/// ```rust +/// use bashkit::{ToolDef, ToolImpl}; +/// +/// let tool = ToolImpl::new( +/// ToolDef::new("greet", "Greet a user") +/// .with_schema(serde_json::json!({ +/// "type": "object", +/// "properties": { "name": {"type": "string"} } +/// })), +/// ) +/// .with_exec_sync(|args| { +/// let name = args.param_str("name").unwrap_or("world"); +/// Ok(format!("hello {name}\n")) +/// }); +/// ``` +#[derive(Clone)] +pub struct ToolImpl { + /// Tool metadata (name, description, schema, tags). + pub def: ToolDef, + /// Async exec (preferred when running in async context). + pub exec: Option, + /// Sync exec (preferred when running in sync context). + pub exec_sync: Option, +} + +impl ToolImpl { + /// Create a `ToolImpl` from a [`ToolDef`] with no exec functions. + pub fn new(def: ToolDef) -> Self { + Self { + def, + exec: None, + exec_sync: None, + } + } + + /// Set the async exec function. + pub fn with_exec(mut self, f: F) -> Self + where + F: Fn(ToolArgs) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + { + self.exec = Some(Arc::new(move |args| Box::pin(f(args)))); + self + } + + /// Set the sync exec function. + pub fn with_exec_sync( + mut self, + f: impl Fn(&ToolArgs) -> std::result::Result + Send + Sync + 'static, + ) -> Self { + self.exec_sync = Some(Arc::new(f)); + self + } +} + +#[async_trait] +impl Builtin for ToolImpl { + async fn execute(&self, ctx: Context<'_>) -> Result { + let params = parse_flags(ctx.args, &self.def.input_schema) + .map_err(|e| crate::error::Error::Execution(format!("{}: {e}", self.def.name)))?; + let tool_args = ToolArgs { + params, + stdin: ctx.stdin.map(String::from), + }; + + // Prefer async, fall back to sync. + let result = if let Some(cb) = &self.exec { + (cb)(tool_args).await + } else if let Some(cb) = &self.exec_sync { + (cb)(&tool_args) + } else { + return Err(crate::error::Error::Execution(format!( + "{}: no exec defined", + self.def.name + ))); + }; + + match result { + Ok(stdout) => Ok(ExecResult::ok(stdout)), + Err(msg) => Ok(ExecResult::err(msg, 1)), + } + } +} + +// ============================================================================ +// Flag parser — `--key value` / `--key=value` → JSON object +// ============================================================================ + +/// Parse `--key value` and `--key=value` flags into a JSON object. +/// Types are coerced according to the schema's property definitions. +/// Unknown flags (not in schema) are kept as strings. +/// Bare `--flag` without a value is treated as `true` if the schema says boolean, +/// otherwise as `true` when the next arg also starts with `--` or is absent. +pub(crate) fn parse_flags( + raw_args: &[String], + schema: &serde_json::Value, +) -> std::result::Result { + let properties = schema + .get("properties") + .and_then(|p| p.as_object()) + .cloned() + .unwrap_or_default(); + + let mut result = serde_json::Map::new(); + let mut i = 0; + + while i < raw_args.len() { + let arg = &raw_args[i]; + + let Some(flag) = arg.strip_prefix("--") else { + return Err(format!("expected --flag, got: {arg}")); + }; + + // --key=value + if let Some((key, raw_value)) = flag.split_once('=') { + let value = coerce_value(raw_value, properties.get(key)); + result.insert(key.to_string(), value); + i += 1; + continue; + } + + // --flag (boolean) or --key value + let key = flag; + let prop_schema = properties.get(key); + let is_boolean = prop_schema + .and_then(|s| s.get("type")) + .and_then(|t| t.as_str()) + == Some("boolean"); + + if is_boolean { + result.insert(key.to_string(), serde_json::Value::Bool(true)); + i += 1; + } else if i + 1 < raw_args.len() && !raw_args[i + 1].starts_with("--") { + let raw_value = &raw_args[i + 1]; + let value = coerce_value(raw_value, prop_schema); + result.insert(key.to_string(), value); + i += 2; + } else { + // No value follows and not boolean — treat as true + result.insert(key.to_string(), serde_json::Value::Bool(true)); + i += 1; + } + } + + Ok(serde_json::Value::Object(result)) +} + +/// Coerce a raw string value to the type declared in the property schema. +fn coerce_value(raw: &str, prop_schema: Option<&serde_json::Value>) -> serde_json::Value { + let type_str = prop_schema + .and_then(|s| s.get("type")) + .and_then(|t| t.as_str()) + .unwrap_or("string"); + + match type_str { + "integer" => raw + .parse::() + .map(serde_json::Value::from) + .unwrap_or_else(|_| serde_json::Value::String(raw.to_string())), + "number" => raw + .parse::() + .map(|n| serde_json::json!(n)) + .unwrap_or_else(|_| serde_json::Value::String(raw.to_string())), + "boolean" => match raw { + "true" | "1" | "yes" => serde_json::Value::Bool(true), + "false" | "0" | "no" => serde_json::Value::Bool(false), + _ => serde_json::Value::String(raw.to_string()), + }, + _ => serde_json::Value::String(raw.to_string()), + } +} + +/// Generate a usage hint from schema properties: `--id --name `. +pub(crate) fn usage_from_schema(schema: &serde_json::Value) -> Option { + let props = schema.get("properties")?.as_object()?; + if props.is_empty() { + return None; + } + let flags: Vec = props + .iter() + .map(|(key, prop)| { + let ty = prop.get("type").and_then(|t| t.as_str()).unwrap_or("value"); + format!("--{key} <{ty}>") + }) + .collect(); + Some(flags.join(" ")) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_flags_basic() { + let schema = serde_json::json!({ + "type": "object", + "properties": { + "id": {"type": "integer"}, + "name": {"type": "string"}, + "verbose": {"type": "boolean"} + } + }); + let args = vec![ + "--id".to_string(), + "42".to_string(), + "--name".to_string(), + "Alice".to_string(), + "--verbose".to_string(), + ]; + let result = parse_flags(&args, &schema).unwrap(); + assert_eq!(result["id"], 42); + assert_eq!(result["name"], "Alice"); + assert_eq!(result["verbose"], true); + } + + #[test] + fn test_parse_flags_equals_syntax() { + let schema = serde_json::json!({ + "type": "object", + "properties": {"id": {"type": "integer"}} + }); + let args = vec!["--id=42".to_string()]; + let result = parse_flags(&args, &schema).unwrap(); + assert_eq!(result["id"], 42); + } + + #[test] + fn test_tool_impl_sync() { + let tool = ToolImpl::new(ToolDef::new("greet", "Greet a user").with_schema( + serde_json::json!({ + "type": "object", + "properties": { "name": {"type": "string"} } + }), + )) + .with_exec_sync(|args| { + let name = args.param_str("name").unwrap_or("world"); + Ok(format!("hello {name}\n")) + }); + + assert!(tool.exec_sync.is_some()); + assert!(tool.exec.is_none()); + assert_eq!(tool.def.name, "greet"); + } + + #[tokio::test] + async fn test_tool_impl_as_builtin() { + let tool = ToolImpl::new(ToolDef::new("greet", "Greet a user").with_schema( + serde_json::json!({ + "type": "object", + "properties": { "name": {"type": "string"} } + }), + )) + .with_exec_sync(|args| { + let name = args.param_str("name").unwrap_or("world"); + Ok(format!("hello {name}\n")) + }); + + // Verify it works as a Builtin + let args = vec!["--name".to_string(), "Alice".to_string()]; + let mut vars = std::collections::HashMap::new(); + let env = std::collections::HashMap::new(); + let mut cwd = std::path::PathBuf::from("/"); + let fs = Arc::new(crate::fs::InMemoryFs::new()); + let ctx = Context::new_for_test(&args, &env, &mut vars, &mut cwd, fs, None); + let result = tool.execute(ctx).await.unwrap(); + assert_eq!(result.stdout, "hello Alice\n"); + assert_eq!(result.exit_code, 0); + } + + #[tokio::test] + async fn test_tool_impl_async_exec() { + let tool = + ToolImpl::new(ToolDef::new("echo_async", "Async echo")).with_exec(|args| async move { + let msg = args.stdin.unwrap_or_default(); + Ok(format!("async: {msg}")) + }); + + assert!(tool.exec.is_some()); + assert!(tool.exec_sync.is_none()); + } + + #[tokio::test] + async fn test_tool_impl_no_exec_errors() { + let tool = ToolImpl::new(ToolDef::new("empty", "No exec")); + + let args = vec![]; + let mut vars = std::collections::HashMap::new(); + let env = std::collections::HashMap::new(); + let mut cwd = std::path::PathBuf::from("/"); + let fs = Arc::new(crate::fs::InMemoryFs::new()); + let ctx = Context::new_for_test(&args, &env, &mut vars, &mut cwd, fs, None); + let result = tool.execute(ctx).await; + assert!(result.is_err()); + } +} diff --git a/specs/014-scripted-tool-orchestration.md b/specs/014-scripted-tool-orchestration.md index bedbd01e..f3051a84 100644 --- a/specs/014-scripted-tool-orchestration.md +++ b/specs/014-scripted-tool-orchestration.md @@ -84,12 +84,52 @@ pub type AsyncToolCallback = Arc< ``` Async variant of `ToolCallback`. Takes owned `ToolArgs` (the future may outlive the -borrow). Register via `builder.async_tool(def, callback)`. Both sync and async +borrow). Register via `builder.async_tool_fn(def, callback)`. Both sync and async callbacks can be mixed in a single `ScriptedTool`. Internally represented as `CallbackKind::Async` and `.await`-ed inside `ToolBuiltinAdapter::execute()`, which is already `async fn`. +### ToolImpl — unified tool unit + +```rust +pub struct ToolImpl { + pub def: ToolDef, + pub exec: Option, + pub exec_sync: Option, +} +``` + +Combines metadata (`ToolDef`) with optional sync and async exec functions. +Implements `Builtin`, so it can be registered in both `Bash` (via `.builtin()`) +and `ScriptedTool`/`ScriptingToolSet` (via `.tool()`). + +When running async, prefers `exec`; falls back to `exec_sync`. +When running sync, prefers `exec_sync`; falls back to blocking on `exec`. + +Builder API: + +```rust +let tool = ToolImpl::new( + ToolDef::new("get_user", "Fetch user by ID") + .with_schema(json!({"type": "object", "properties": {"id": {"type": "integer"}}})), +) +.with_exec_sync(|args| { + let id = args.param_i64("id").ok_or("missing --id")?; + Ok(format!("{{\"id\":{id}}}\n")) +}) +.with_exec(|args| async move { + let id = args.param_i64("id").ok_or("missing --id")?; + Ok(format!("{{\"id\":{id}}}\n")) +}); + +// Register in ScriptedTool +ScriptedTool::builder("api").tool(tool).build(); +``` + +Type aliases for backward compatibility: `ToolCallback = SyncToolExec`, +`AsyncToolCallback = AsyncToolExec`. + ### ContextVar propagation (Python) Python callbacks (both sync and async) automatically see `contextvars.ContextVar` @@ -125,14 +165,14 @@ Unknown flags (not in schema) are kept as strings. ### ScriptedToolBuilder -Two arguments per tool: definition + callback. Use `.tool()` for sync and -`.async_tool()` for async callbacks. +Two arguments per tool: definition + callback. Use `.tool_fn()` for sync and +`.async_tool_fn()` for async callbacks. ```rust ScriptedTool::builder("api_name") .locale("en-US") .short_description("...") - .tool( + .tool_fn( ToolDef::new("get_user", "Fetch user by ID") .with_schema(json!({"type": "object", "properties": {"id": {"type": "integer"}}})), |args| { @@ -140,7 +180,7 @@ ScriptedTool::builder("api_name") Ok(format!("{{\"id\":{id}}}\n")) }, ) - .async_tool( + .async_tool_fn( ToolDef::new("fetch_url", "Fetch a URL"), |args| async move { let url = args.param_str("url").unwrap_or("?"); @@ -189,7 +229,7 @@ that lists only tool names + one-liners, deferring full schemas to `help`: ```rust ScriptedTool::builder("api") .compact_prompt(true) - .tool(...) + .tool_fn(...) .build() ``` @@ -241,7 +281,7 @@ Use the standard Rust closure-capture pattern with `Arc` to share resources: ```rust let client = Arc::new(build_authenticated_client()); let c = client.clone(); -builder.tool(ToolDef::new("get_user", "..."), move |args| { +builder.tool_fn(ToolDef::new("get_user", "..."), move |args| { let resp = c.get(&format!("/users/{}", args.param_i64("id").unwrap())); Ok(resp.text()?) }); @@ -265,7 +305,7 @@ exposes them via `take_last_execution_trace()`. This trace is for observability telemetry, not scoring: ```rust -let mut tool = ScriptedTool::builder("api").tool(...).build(); +let mut tool = ScriptedTool::builder("api").tool_fn(...).build(); let _resp = tool.execute(ToolRequest::new("discover --search user\nhelp get_user")).await; let trace = tool.take_last_execution_trace().unwrap(); assert_eq!(trace.invocations[0].name, "discover"); @@ -308,14 +348,14 @@ based on `DiscoveryMode`: // Exclusive mode (default): tools() returns [ScriptedTool] let toolset = ScriptingToolSet::builder("api") .short_description("My API") - .tool(ToolDef::new("get_user", "Fetch user").with_schema(...), callback) + .tool_fn(ToolDef::new("get_user", "Fetch user").with_schema(...), callback) .build(); let tools = toolset.tools(); // vec![ScriptedTool] // Discovery mode: tools() returns [ScriptedTool, DiscoverTool] let toolset = ScriptingToolSet::builder("api") .short_description("My API") - .tool(ToolDef::new("get_user", "Fetch user").with_category("users"), callback) + .tool_fn(ToolDef::new("get_user", "Fetch user").with_category("users"), callback) .with_discovery() .build(); let tools = toolset.tools(); // vec![ScriptedTool(compact), DiscoverTool] @@ -338,17 +378,17 @@ Builder API mirrors `ScriptedToolBuilder`: `.tool()`, `.env()`, `.limits()`, ## Module location -`crates/bashkit/src/scripted_tool/` - ``` +tool_def.rs — ToolDef, ToolArgs, ToolImpl, SyncToolExec, AsyncToolExec, parse_flags scripted_tool/ -├── mod.rs — ToolDef, ToolCallback, ScriptedToolBuilder, ScriptedTool struct, tests -├── execute.rs — Tool impl, ToolBuiltinAdapter, documentation helpers -└── toolset.rs — ScriptingToolSet, ScriptingToolSetBuilder, DiscoveryMode +├── mod.rs — CallbackKind, ScriptedToolBuilder, ScriptedTool, re-exports from tool_def +├── execute.rs — Tool impl, ToolBuiltinAdapter, documentation helpers +└── toolset.rs — ScriptingToolSet, ScriptingToolSetBuilder, DiscoveryMode ``` Public exports from `lib.rs` (gated by `scripted_tool` feature): -`ToolDef`, `ToolArgs`, `ToolCallback`, `ScriptedTool`, `ScriptedToolBuilder`, +`ToolDef`, `ToolArgs`, `ToolImpl`, `SyncToolExec`, `AsyncToolExec`, +`ToolCallback`, `AsyncToolCallback` (aliases), `ScriptedTool`, `ScriptedToolBuilder`, `ScriptingToolSet`, `ScriptingToolSetBuilder`, `DiscoverTool`, `DiscoveryMode`. ## Example