diff --git a/.gitignore b/.gitignore
index 1e47926b2..506931c98 100644
--- a/.gitignore
+++ b/.gitignore
@@ -58,12 +58,14 @@ docs/example1.dat
docs/example3.dat
python/.eggs/
python/doc/
+python/examples/.ipynb_checkpoints
# Egg metadata
*.egg-info
.vscode
.idea/
.pytest_cache/
+.ruff_cache/
pkgs
docker_cache
.gdb_history
diff --git a/python/README.md b/python/README.md
index 0a26ee979..8ba7d1f27 100644
--- a/python/README.md
+++ b/python/README.md
@@ -79,6 +79,102 @@ pyarrow_batches = df.collect()
Check [DataFusion python](https://datafusion.apache.org/python/) provides more examples and manuals.
+## Jupyter Notebook Support
+
+PyBallista provides first-class Jupyter notebook support with SQL magic commands and rich HTML rendering.
+
+### Install Jupyter extras first:
+```bash
+pip install "ballista[jupyter]"
+```
+
+### HTML Table Rendering
+
+DataFrames automatically render as styled HTML tables in Jupyter notebooks:
+
+```python
+from ballista import BallistaSessionContext
+
+ctx = BallistaSessionContext("df://localhost:50050")
+df = ctx.sql("SELECT * FROM my_table LIMIT 10")
+df # Renders as HTML table via _repr_html_()
+```
+
+### SQL Magic Commands
+
+For a more interactive SQL experience, load the Ballista Jupyter extension:
+
+```python
+# Load the extension
+%load_ext ballista.jupyter
+
+# Connect to a Ballista cluster
+%ballista connect df://localhost:50050
+
+# Register a Parquet table from a .parquet file
+%register parquet public.test_data_v1 ../testdata/test.parquet
+
+# Check connection status
+%ballista status
+
+# List registered tables
+%ballista tables
+
+# Show table schema
+%ballista schema my_table
+
+# Execute a simple query (line magic)
+%sql SELECT COUNT(*) FROM orders
+
+# Execute a complex query (cell magic)
+%%sql
+SELECT
+ customer_id,
+ SUM(amount) as total
+FROM orders
+GROUP BY customer_id
+ORDER BY total DESC
+LIMIT 10
+```
+
+You can also store results in a variable:
+
+```python
+%%sql my_result
+SELECT * FROM orders WHERE status = 'pending'
+```
+
+### Execution Plan Visualization
+
+Visualize query execution plans directly in notebooks:
+
+```python
+df = ctx.sql("SELECT * FROM orders WHERE amount > 100")
+df.explain_visual() # Displays SVG visualization
+
+# With runtime statistics
+df.explain_visual(analyze=True)
+```
+
+> **Note:** Full SVG visualization requires graphviz to be installed (`brew install graphviz` on macOS).
+
+### Progress Indicators
+
+For long-running queries, use `collect_with_progress()` to see execution status:
+
+```python
+df = ctx.sql("SELECT * FROM large_table")
+batches = df.collect_with_progress()
+```
+
+### Example Notebooks
+
+See the `examples/` directory for Jupyter notebooks demonstrating various features:
+
+- `getting_started.ipynb` - Basic connection and queries
+- `dataframe_api.ipynb` - DataFrame transformations
+- `distributed_queries.ipynb` - Multi-stage distributed query examples
+
## Scheduler and Executor
Scheduler and executors can be configured and started from python code.
diff --git a/python/examples/dataframe_api.ipynb b/python/examples/dataframe_api.ipynb
new file mode 100644
index 000000000..aaa86df75
--- /dev/null
+++ b/python/examples/dataframe_api.ipynb
@@ -0,0 +1,419 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "\n",
+ "# DataFrame API with Ballista\n",
+ "\n",
+ "This notebook demonstrates the DataFrame API available in Ballista.\n",
+ "\n",
+ "The DataFrame API provides a programmatic way to build queries, which can be\n",
+ "more convenient than writing SQL for complex transformations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from ballista import BallistaSessionContext, setup_test_cluster\n",
+ "from datafusion import col, lit\n",
+ "from datafusion import functions as f\n",
+ "\n",
+ "# Set up test cluster and connect\n",
+ "host, port = setup_test_cluster()\n",
+ "ctx = BallistaSessionContext(f\"df://{host}:{port}\")\n",
+ "\n",
+ "# Register sample data\n",
+ "ctx.register_parquet(\"test_data\", \"../testdata/test.parquet\")\n",
+ "ctx.register_csv(\"csv_data\", \"../testdata/test.csv\", has_header=True)\n",
+ "\n",
+ "print(f\"Connected! Session ID: {ctx.session_id}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Basic Operations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Read a table as a DataFrame\n",
+ "df = ctx.table(\"test_data\")\n",
+ "\n",
+ "# Display schema\n",
+ "print(\"Schema:\")\n",
+ "for field in df.schema():\n",
+ " print(f\" {field.name}: {field.type}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Show first few rows\n",
+ "df.show(5)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Selecting Columns"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Select specific columns by name\n",
+ "df.select(\"id\", \"bool_col\", \"tinyint_col\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Select with column expressions\n",
+ "df.select(\n",
+ " col(\"id\"),\n",
+ " col(\"tinyint_col\").alias(\"tiny\"),\n",
+ " (col(\"id\") * lit(10)).alias(\"id_times_10\")\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Filtering Data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Simple filter\n",
+ "df.filter(col(\"id\") > lit(4))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Complex filter with AND/OR\n",
+ "df.filter(\n",
+ " (col(\"id\") >= lit(2)) & (col(\"id\") <= lit(5))\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Filter with boolean column\n",
+ "df.filter(col(\"bool_col\") == lit(True))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Sorting"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sort ascending\n",
+ "df.sort(col(\"id\").sort(ascending=True))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sort descending\n",
+ "df.sort(col(\"id\").sort(ascending=False))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Limiting Results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Limit number of rows\n",
+ "df.limit(3)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Aggregations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Count all rows\n",
+ "result = df.aggregate([], [f.count_star().alias(\"total_count\")])\n",
+ "result"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Group by and aggregate\n",
+ "df.aggregate(\n",
+ " [col(\"bool_col\")],\n",
+ " [\n",
+ " f.count_star().alias(\"count\"),\n",
+ " f.sum(col(\"id\")).alias(\"sum_id\"),\n",
+ " f.avg(col(\"id\")).alias(\"avg_id\"),\n",
+ " ]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Distinct Values"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Get distinct values\n",
+ "df.select(\"bool_col\").distinct()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Chaining Operations\n",
+ "\n",
+ "DataFrame operations can be chained together to build complex transformations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Complex chained query\n",
+ "result = (\n",
+ " ctx.table(\"test_data\")\n",
+ " .select(\"id\", \"bool_col\", \"tinyint_col\")\n",
+ " .filter(col(\"id\") > lit(2))\n",
+ " .sort(col(\"id\").sort(ascending=False))\n",
+ " .limit(5)\n",
+ ")\n",
+ "\n",
+ "result"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# View the execution plan for the chained query\n",
+ "print(result.explain())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Visual execution plan\n",
+ "result.explain_visual()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Collecting Results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Collect as Arrow batches\n",
+ "batches = result.collect()\n",
+ "print(f\"Got {len(batches)} batch(es)\")\n",
+ "print(f\"Total rows: {sum(len(batch) for batch in batches)}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Collect as Arrow table\n",
+ "table = result.to_arrow_table()\n",
+ "print(f\"Arrow table: {table.num_rows} rows, {table.num_columns} columns\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Convert to Pandas\n",
+ "pdf = result.to_pandas()\n",
+ "pdf"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Get the count without collecting all data\n",
+ "count = ctx.table(\"test_data\").count()\n",
+ "print(f\"Total rows in test_data: {count}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Working with CSV Data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Read CSV directly without registering\n",
+ "csv_df = ctx.read_csv(\"../testdata/test.csv\", has_header=True)\n",
+ "\n",
+ "# Show schema and data\n",
+ "print(\"CSV Schema:\")\n",
+ "for field in csv_df.schema():\n",
+ " print(f\" {field.name}: {field.type}\")\n",
+ "\n",
+ "csv_df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Filter CSV data\n",
+ "csv_df.filter(col(\"a\") > lit(2))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Next Steps\n",
+ "\n",
+ "- See `distributed_queries.ipynb` for examples of distributed query execution\n",
+ "- Check the [DataFusion Python documentation](https://datafusion.apache.org/python/) for more DataFrame operations\n",
+ "- Review the SQL magic commands in `getting_started.ipynb` for interactive querying"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/python/examples/distributed_queries.ipynb b/python/examples/distributed_queries.ipynb
new file mode 100644
index 000000000..2c3c14902
--- /dev/null
+++ b/python/examples/distributed_queries.ipynb
@@ -0,0 +1,411 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "\n",
+ "# Distributed Queries with Ballista\n",
+ "\n",
+ "This notebook demonstrates distributed query execution features in Ballista.\n",
+ "\n",
+ "## Overview\n",
+ "\n",
+ "Ballista is a distributed query engine that can execute queries across multiple\n",
+ "nodes. When you submit a query, Ballista:\n",
+ "\n",
+ "1. Parses and optimizes the query\n",
+ "2. Creates a distributed execution plan\n",
+ "3. Distributes work across executors\n",
+ "4. Collects and returns results\n",
+ "\n",
+ "This enables processing of datasets much larger than a single machine's memory."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from ballista import BallistaSessionContext, setup_test_cluster\n",
+ "from datafusion import col, lit\n",
+ "from datafusion import functions as f\n",
+ "\n",
+ "# Set up test cluster and connect\n",
+ "host, port = setup_test_cluster()\n",
+ "ctx = BallistaSessionContext(f\"df://{host}:{port}\")\n",
+ "\n",
+ "# Register sample data\n",
+ "ctx.register_parquet(\"test_data\", \"../testdata/test.parquet\")\n",
+ "\n",
+ "print(f\"Connected! Session ID: {ctx.session_id}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Execution Plans\n",
+ "\n",
+ "Understanding execution plans is key to optimizing distributed queries."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create a query with multiple stages\n",
+ "df = ctx.sql(\"\"\"\n",
+ " SELECT \n",
+ " bool_col,\n",
+ " COUNT(*) as cnt,\n",
+ " SUM(id) as sum_id,\n",
+ " AVG(tinyint_col) as avg_tiny\n",
+ " FROM test_data\n",
+ " WHERE id > 2\n",
+ " GROUP BY bool_col\n",
+ " ORDER BY cnt DESC\n",
+ "\"\"\")\n",
+ "\n",
+ "df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# View the logical plan\n",
+ "print(\"Logical Plan:\")\n",
+ "print(df.explain())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Visualize the execution plan\n",
+ "# This shows the query plan as a graph (requires graphviz for full SVG)\n",
+ "df.explain_visual()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# View the plan with runtime statistics (analyze=True runs the query)\n",
+ "print(\"Analyzed Plan (with statistics):\")\n",
+ "print(df.explain(analyze=True))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Multi-Stage Queries\n",
+ "\n",
+ "Complex queries may involve multiple stages of distributed execution."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Subquery example\n",
+ "result = ctx.sql(\"\"\"\n",
+ " WITH stats AS (\n",
+ " SELECT \n",
+ " bool_col,\n",
+ " COUNT(*) as cnt\n",
+ " FROM test_data\n",
+ " GROUP BY bool_col\n",
+ " )\n",
+ " SELECT \n",
+ " t.id,\n",
+ " t.bool_col,\n",
+ " s.cnt as group_count\n",
+ " FROM test_data t\n",
+ " JOIN stats s ON t.bool_col = s.bool_col\n",
+ " WHERE t.id <= 5\n",
+ " ORDER BY t.id\n",
+ "\"\"\")\n",
+ "\n",
+ "result"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# View the execution plan - notice the join and exchange stages\n",
+ "result.explain_visual()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## DataFrame API for Complex Transformations"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Build a complex transformation using the DataFrame API\n",
+ "df1 = ctx.table(\"test_data\")\n",
+ "\n",
+ "# Aggregate to get group statistics\n",
+ "group_stats = df1.aggregate(\n",
+ " [col(\"bool_col\")],\n",
+ " [\n",
+ " f.count_star().alias(\"group_count\"),\n",
+ " f.avg(col(\"id\")).alias(\"avg_id\"),\n",
+ " ]\n",
+ ")\n",
+ "\n",
+ "group_stats"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Join original data with statistics\n",
+ "joined = df1.join(\n",
+ " group_stats,\n",
+ " on=\"bool_col\",\n",
+ " how=\"inner\"\n",
+ ")\n",
+ "\n",
+ "joined.select(\"id\", \"bool_col\", \"group_count\", \"avg_id\").limit(10)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Window Functions\n",
+ "\n",
+ "Window functions allow computations across related rows."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Window function example\n",
+ "window_result = ctx.sql(\"\"\"\n",
+ " SELECT \n",
+ " id,\n",
+ " bool_col,\n",
+ " tinyint_col,\n",
+ " SUM(tinyint_col) OVER (\n",
+ " PARTITION BY bool_col \n",
+ " ORDER BY id\n",
+ " ) as running_sum,\n",
+ " ROW_NUMBER() OVER (\n",
+ " PARTITION BY bool_col \n",
+ " ORDER BY id\n",
+ " ) as row_num\n",
+ " FROM test_data\n",
+ " ORDER BY bool_col, id\n",
+ "\"\"\")\n",
+ "\n",
+ "window_result"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Progress Tracking for Long Queries\n",
+ "\n",
+ "For long-running queries, you can track progress."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Execute a query with progress tracking\n",
+ "df = ctx.sql(\"SELECT * FROM test_data\")\n",
+ "\n",
+ "# collect_with_progress shows elapsed time in Jupyter\n",
+ "batches = df.collect_with_progress()\n",
+ "print(f\"Collected {len(batches)} batch(es)\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# You can also provide a custom callback\n",
+ "def my_progress_callback(status, progress):\n",
+ " if progress < 0:\n",
+ " print(f\"Status: {status} (in progress...)\")\n",
+ " else:\n",
+ " print(f\"Status: {status} ({progress:.0%} complete)\")\n",
+ "\n",
+ "df = ctx.sql(\"SELECT * FROM test_data\")\n",
+ "batches = df.collect_with_progress(callback=my_progress_callback)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Writing Results\n",
+ "\n",
+ "Distributed write operations for large result sets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Prepare a query result\n",
+ "df = ctx.sql(\"\"\"\n",
+ " SELECT \n",
+ " id,\n",
+ " bool_col,\n",
+ " tinyint_col * 2 as doubled\n",
+ " FROM test_data\n",
+ " WHERE id > 3\n",
+ "\"\"\")\n",
+ "\n",
+ "df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Write to Parquet (distributed write)\n",
+ "# df.write_parquet(\"../target/output.parquet\")\n",
+ "\n",
+ "# Write to CSV\n",
+ "# df.write_csv(\"../target/output.csv\")\n",
+ "\n",
+ "# Write to JSON\n",
+ "# df.write_json(\"../target/output.json\")\n",
+ "\n",
+ "print(\"Write operations are commented out - uncomment to test\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Best Practices for Distributed Queries\n",
+ "\n",
+ "1. **Filter early**: Push filters as close to the data source as possible\n",
+ "2. **Project early**: Select only needed columns to reduce data movement\n",
+ "3. **Partition wisely**: Ensure data is partitioned for efficient joins\n",
+ "4. **Check plans**: Use `explain()` and `explain_visual()` to understand execution\n",
+ "5. **Monitor progress**: Use `collect_with_progress()` for long queries"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Example: Optimized query pattern\n",
+ "optimized = (\n",
+ " ctx.table(\"test_data\")\n",
+ " # 1. Filter early\n",
+ " .filter(col(\"id\") > lit(2))\n",
+ " # 2. Project only needed columns\n",
+ " .select(\"id\", \"bool_col\", \"tinyint_col\")\n",
+ " # 3. Aggregate\n",
+ " .aggregate(\n",
+ " [col(\"bool_col\")],\n",
+ " [f.count_star().alias(\"cnt\")]\n",
+ " )\n",
+ ")\n",
+ "\n",
+ "# 4. Check the plan\n",
+ "print(\"Optimized plan:\")\n",
+ "print(optimized.explain())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Next Steps\n",
+ "\n",
+ "- Review the [Ballista Architecture docs](https://datafusion.apache.org/ballista/)\n",
+ "- Learn about cluster deployment and configuration\n",
+ "- Explore advanced features like custom functions and plugins"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/python/examples/getting_started.ipynb b/python/examples/getting_started.ipynb
new file mode 100644
index 000000000..b2f725cba
--- /dev/null
+++ b/python/examples/getting_started.ipynb
@@ -0,0 +1,329 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "\n",
+ "# Getting Started with PyBallista\n",
+ "\n",
+ "This notebook demonstrates how to get started with Ballista using Python.\n",
+ "\n",
+ "## Prerequisites\n",
+ "\n",
+ "1. Install PyBallista: `pip install ballista`\n",
+ "2. Have a Ballista cluster running (or use the built-in test cluster)\n",
+ "\n",
+ "## Overview\n",
+ "\n",
+ "Ballista is a distributed query engine built on Apache DataFusion. PyBallista provides:\n",
+ "\n",
+ "- **BallistaSessionContext**: Drop-in replacement for DataFusion's SessionContext\n",
+ "- **SQL Magic Commands**: Interactive SQL in Jupyter notebooks via `%sql` and `%%sql`\n",
+ "- **DataFrame API**: Full DataFrame API for data transformations\n",
+ "- **Rich HTML Display**: DataFrames render as styled HTML tables"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Method 1: Python API\n",
+ "\n",
+ "The most straightforward way to use Ballista is via the Python API."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import ballista\n",
+ "from ballista import BallistaSessionContext, setup_test_cluster\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "# Check versions\n",
+ "print(f\"Ballista version: {ballista.__version__}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# For this demo, we'll use the built-in test cluster\n",
+ "# In production, you would connect to your Ballista scheduler:\n",
+ "# ctx = BallistaSessionContext(\"df://your-scheduler:50050\")\n",
+ "\n",
+ "host, port = setup_test_cluster()\n",
+ "ctx = BallistaSessionContext(f\"df://{host}:{port}\")\n",
+ "\n",
+ "print(f\"Connected to Ballista at {host}:{port}\")\n",
+ "print(f\"Session ID: {ctx.session_id}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "# Register two .parquet files under the same database schema\n",
+ "ctx.register_parquet(\"public.test_data_v1\", \"../testdata/test.parquet\")\n",
+ "ctx.register_parquet(\"public.test_data_v2\", \"../testdata/test.parquet\")\n",
+ "\n",
+ "# List registered tables\n",
+ "print(\"Registered schemas and tables:\\n\", ctx.get_tables())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Execute a SQL query - the DataFrame will render as a nice HTML table\n",
+ "df = ctx.sql(\"SELECT * FROM public.test_data_v1 LIMIT 10\")\n",
+ "df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# You can also use show() for terminal-style output\n",
+ "df.show(5)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Get the execution plan\n",
+ "print(df.explain())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Visualize the execution plan (requires graphviz for full SVG)\n",
+ "df.explain_visual()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Method 2: SQL Magic Commands\n",
+ "\n",
+ "For a more interactive experience, use the SQL magic commands!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Load the Ballista Jupyter extension\n",
+ "%reload_ext ballista.jupyter"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Connect to the cluster\n",
+ "%ballista connect df://localhost:39431"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "# Register a Parquet table from a .parquet file\n",
+ "%register parquet public.test_data_v1 ../testdata/test.parquet"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Check connection status\n",
+ "%ballista status"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# List registered tables\n",
+ "%ballista tables"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Execute a single-line SQL query\n",
+ "%sql select count(*) as count_d from public.test_data_v1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%sql\n",
+ "-- Multi-line queries work with %%sql cell magic\n",
+ "SELECT\n",
+ " id,\n",
+ " bool_col,\n",
+ " tinyint_col\n",
+ "FROM test_data_v1\n",
+ "WHERE id > 2\n",
+ "ORDER BY id\n",
+ "LIMIT 5"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%sql my_result\n",
+ "-- Store the result in a variable for further processing\n",
+ "SELECT * FROM test_data_v1 WHERE id <= 3"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# View query history\n",
+ "%ballista history"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%ballista help"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Data Export\n",
+ "\n",
+ "Ballista supports exporting data in multiple formats."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df = ctx.sql(\"SELECT * FROM test_data_v1 LIMIT 100\")\n",
+ "\n",
+ "# Export to various formats\n",
+ "# df.write_parquet(\"output.parquet\")\n",
+ "# df.write_csv(\"output.csv\")\n",
+ "# df.write_json(\"output.json\")\n",
+ "\n",
+ "# Convert to Arrow, Pandas, or Polars\n",
+ "arrow_table = df.to_arrow_table()\n",
+ "print(f\"Arrow Table Schema:\\n{arrow_table.schema}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Convert to pandas\n",
+ "pandas_df = df.to_pandas()\n",
+ "pandas_df.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Next Steps\n",
+ "\n",
+ "- Check out the `dataframe_api.ipynb` notebook for more DataFrame operations\n",
+ "- See `distributed_queries.ipynb` for examples of distributed query execution\n",
+ "- Read the [PyBallista documentation](https://datafusion.apache.org/ballista/) for more details"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/python/pyproject.toml b/python/pyproject.toml
index bce9f5d11..84f36c7eb 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -46,6 +46,11 @@ dependencies = [
]
dynamic = ["version"]
+[project.optional-dependencies]
+jupyter = [
+ "ipython>=8.0.0",
+]
+
[project.urls]
homepage = "https://datafusion.apache.org/ballista"
documentation = "https://datafusion.apache.org/ballista"
diff --git a/python/python/ballista/__init__.py b/python/python/ballista/__init__.py
index 0c3aaa75a..5a5ef7bb0 100644
--- a/python/python/ballista/__init__.py
+++ b/python/python/ballista/__init__.py
@@ -25,7 +25,11 @@
BallistaExecutor,
setup_test_cluster,
)
-from .extension import BallistaSessionContext
+from .extension import (
+ BallistaSessionContext,
+ DistributedDataFrame,
+ ExecutionPlanVisualization,
+)
__version__ = importlib_metadata.version(__name__)
@@ -34,4 +38,6 @@
"BallistaScheduler",
"BallistaExecutor",
"BallistaSessionContext",
+ "DistributedDataFrame",
+ "ExecutionPlanVisualization",
]
diff --git a/python/python/ballista/extension.py b/python/python/ballista/extension.py
index e1217e845..87f2d6d66 100644
--- a/python/python/ballista/extension.py
+++ b/python/python/ballista/extension.py
@@ -19,8 +19,12 @@
from datafusion.dataframe import Compression
from typing import (
+ List,
Union,
+ Optional,
+ Callable
)
+import warnings
from ._internal_ballista import create_ballista_data_frame
@@ -28,26 +32,6 @@
from ._internal_ballista import ParquetWriterOptions as ParquetWriterOptionsInternal
import pathlib
-# DataFrame execution methods which should be automatically
-# overridden.
-
-OVERRIDDEN_EXECUTION_METHODS = [
- "show",
- "count",
- "collect",
- "collect_partitioned",
- "write_json",
- "to_arrow_table",
- "to_pandas",
- "to_pydict",
- "to_polars",
- "to_pylist",
- "_repr_html_",
- "execute_stream",
- "execute_stream_partitioned",
-]
-
-
# class used to redefine DataFrame object
# intercepting execution methods and methods
# which returns `DataFrame`
@@ -90,10 +74,6 @@ def method_wrapper(*args, **kwargs):
#
attrs[base_name] = __wrap_dataframe_result(base_value)
- # TODO: we could do better here
- for function in OVERRIDDEN_EXECUTION_METHODS:
- attrs[function] = __wrap_dataframe_execution(function)
-
return super().__new__(cls, name, bases, attrs)
@@ -131,21 +111,18 @@ def method_wrapper(*args, **kwargs):
# serialize it and invoke ballista client to execute it
#
# this class keeps reference to remote ballista
-
-
-class DistributedDataFrame(DataFrame, metaclass=RedefiningDataFrameMeta):
+class DistributedDataFrame(DataFrame, metaclass=type):
def __init__(self, df: DataFrame, session_id: str, address: str):
super().__init__(df.df)
self.address = address
- self.session_id = session_id
-
+ self._session_id = session_id
#
# this will create a ballista dataframe, which has ballista
# session context, and ballista planner.
#
def _to_internal_df(self):
blob_plan = self.logical_plan().to_proto()
- df = create_ballista_data_frame(blob_plan, self.address, self.session_id)
+ df = create_ballista_data_frame(blob_plan, self.address, self._session_id)
return df
def write_csv(self, path, with_header=False):
@@ -229,9 +206,325 @@ def write_parquet(
df = self._to_internal_df()
df.write_parquet(str(path), compression.value, compression_level)
+ def explain_visual(self, analyze: bool = False) -> "ExecutionPlanVisualization":
+ """
+ Generate a visual representation of the execution plan.
+
+ This method creates an SVG visualization of the query execution plan,
+ which can be displayed directly in Jupyter notebooks.
+
+ Args:
+ analyze: If True, includes runtime statistics from actual execution.
+
+ Returns:
+ ExecutionPlanVisualization: An object that renders as SVG in Jupyter.
+
+ Example:
+ >>> df = ctx.sql("SELECT * FROM orders WHERE amount > 100")
+ >>> df.explain_visual() # Displays SVG in notebook
+ >>> viz = df.explain_visual(analyze=True)
+ >>> viz.save("plan.svg") # Save to file
+ """
+ # Get the execution plan as a string representation
+ # Note: explain() prints but doesn't return a string, so we use logical_plan()
+ try:
+ plan = self.logical_plan()
+ plan_str = plan.display_indent()
+ except Exception:
+ # Fallback if logical_plan() fails
+ plan_str = "Unable to retrieve execution plan"
+ return ExecutionPlanVisualization(plan_str, analyze=analyze)
+
+ def collect_with_progress(
+ self,
+ callback: Optional[Callable] = None,
+ poll_interval: float = 0.5,
+ ):
+ """
+ Collect results with progress indication.
+
+ For long-running queries, this method provides progress updates
+ through a callback function or displays a progress bar in Jupyter.
+
+ Args:
+ callback: Optional function to call with progress updates.
+ Signature: callback(status: str, progress: float)
+ poll_interval: How often to check progress (seconds).
+
+ Returns:
+ The collected result batches.
+
+ Example:
+ >>> def my_callback(status, progress):
+ ... print(f"{status}: {progress:.1%}")
+ >>> batches = df.collect_with_progress(callback=my_callback)
+ """
+ import threading
+ import time
+
+ result = [None]
+ error = [None]
+ done = threading.Event()
+
+ def execute():
+ try:
+ result[0] = self.collect()
+ except Exception as e:
+ error[0] = e
+ finally:
+ done.set()
+
+ thread = threading.Thread(target=execute)
+ thread.start()
+
+ # Check if we're in a Jupyter environment
+ try:
+ from IPython.display import clear_output
+ from IPython.core.getipython import get_ipython
+
+ in_jupyter = get_ipython() is not None
+ except (ImportError, AttributeError):
+ in_jupyter = False
+
+ start_time = time.time()
+
+ if in_jupyter and callback is None:
+ # Display a simple progress indicator
+ try:
+ while not done.wait(timeout=poll_interval):
+ elapsed = time.time() - start_time
+ clear_output(wait=True)
+ print(f"⏳ Query executing... ({elapsed:.1f}s elapsed)")
+
+ clear_output(wait=True)
+ elapsed = time.time() - start_time
+ print(f"✓ Query completed in {elapsed:.1f}s")
+ except Exception:
+ pass # Ignore display errors
+ elif callback is not None:
+ while not done.wait(timeout=poll_interval):
+ elapsed = time.time() - start_time
+ callback(f"Executing ({elapsed:.1f}s)", -1.0) # -1 means indeterminate
+
+ elapsed = time.time() - start_time
+ callback(f"Completed in {elapsed:.1f}s", 1.0)
+ else:
+ done.wait()
+
+ thread.join()
+
+ if error[0] is not None:
+ raise error[0]
+
+ return result[0]
+
+
+class ExecutionPlanVisualization:
+ """
+ A wrapper for execution plan visualizations that can render as SVG in Jupyter.
+
+ This class takes the text representation of an execution plan and converts
+ it to a Graphviz DOT format, which is then rendered as SVG.
+ """
+
+ def __init__(self, plan_str: str, analyze: bool = False):
+ self.plan_str = plan_str
+ self.analyze = analyze
+ self._svg_cache: Optional[str] = None
+
+ def _parse_plan_to_dot(self) -> str:
+ """Convert the plan string to DOT format for Graphviz."""
+ lines = self.plan_str.strip().split("\n")
+
+ dot_lines = [
+ "digraph ExecutionPlan {",
+ ' rankdir=TB;',
+ ' node [shape=box, style="rounded,filled", fontname="Helvetica"];',
+ ' edge [fontname="Helvetica"];',
+ "",
+ ]
+
+ nodes = []
+ edges = []
+ node_id = 0
+ stack = [] # (indent_level, node_id)
+
+ for line in lines:
+ if not line.strip():
+ continue
+
+ # Calculate indent level
+ indent = len(line) - len(line.lstrip())
+ content = line.strip()
+
+ # Skip non-plan lines
+ if content.startswith("physical_plan") or content.startswith("logical_plan"):
+ continue
+
+ # Create a node for this plan element
+ current_id = node_id
+ node_id += 1
+
+ # Determine node color based on operation type
+ color = "#E3F2FD" # Default light blue
+ if "Scan" in content or "TableScan" in content:
+ color = "#E8F5E9" # Light green for scans
+ elif "Filter" in content:
+ color = "#FFF3E0" # Light orange for filters
+ elif "Aggregate" in content or "HashAggregate" in content:
+ color = "#F3E5F5" # Light purple for aggregations
+ elif "Join" in content:
+ color = "#FFEBEE" # Light red for joins
+ elif "Sort" in content:
+ color = "#E0F7FA" # Light cyan for sorts
+ elif "Projection" in content:
+ color = "#FFF8E1" # Light amber for projections
+
+ # Escape special characters for DOT format
+ label = content.replace('"', '\\"').replace("\n", "\\n")
+ if len(label) > 60:
+ # Wrap long labels
+ label = label[:57] + "..."
+
+ nodes.append(f' node{current_id} [label="{label}", fillcolor="{color}"];')
+
+ # Connect to parent based on indentation
+ while stack and stack[-1][0] >= indent:
+ stack.pop()
+
+ if stack:
+ parent_id = stack[-1][1]
+ edges.append(f" node{parent_id} -> node{current_id};")
+
+ stack.append((indent, current_id))
+
+ dot_lines.extend(nodes)
+ dot_lines.append("")
+ dot_lines.extend(edges)
+ dot_lines.append("}")
+
+ return "\n".join(dot_lines)
+
+ def to_dot(self) -> str:
+ """Get the DOT representation of the execution plan."""
+ return self._parse_plan_to_dot()
+
+ def to_svg(self) -> str:
+ """
+ Convert the execution plan to SVG format.
+
+ Requires graphviz to be installed. If graphviz is not available,
+ returns a simple HTML representation instead.
+ """
+ if self._svg_cache is not None:
+ return self._svg_cache
+
+ dot_source = self._parse_plan_to_dot()
+
+ try:
+ import subprocess
+
+ # Try to use graphviz's dot command
+ process = subprocess.run(
+ ["dot", "-Tsvg"],
+ input=dot_source.encode(),
+ capture_output=True,
+ timeout=30,
+ )
+
+ if process.returncode == 0:
+ self._svg_cache = process.stdout.decode()
+ return self._svg_cache
+ except (subprocess.SubprocessError, FileNotFoundError, subprocess.TimeoutExpired) as e:
+ warnings.warn(f"Could not convert the execution plan to SVG format: {e}")
+ pass
+
+ # Fallback: return a pre-formatted HTML representation
+        escaped_plan = (
+            self.plan_str.replace("&", "&amp;")
+            .replace("<", "&lt;")
+            .replace(">", "&gt;")
+        )
+        self._svg_cache = f"""
+        <div style="font-family: monospace; border: 1px solid #ddd; border-radius: 4px; padding: 12px;">
+            <div style="font-weight: bold; margin-bottom: 8px;">
+                Execution Plan {'(with statistics)' if self.analyze else ''}
+            </div>
+            <div style="color: #888; margin-bottom: 8px;">
+                Install graphviz for visual diagram: <code>brew install graphviz</code>
+            </div>
+            <pre style="margin: 0;">{escaped_plan}</pre>
+        </div>
+        """
+        return self._svg_cache
+
+ def save(self, path: str) -> None:
+ """Save the visualization to a file (SVG or DOT format)."""
+ if path.endswith(".dot"):
+ content = self.to_dot()
+ else:
+ content = self.to_svg()
+
+ with open(path, "w") as f:
+ f.write(content)
+
+ def _repr_html_(self) -> str:
+ """HTML representation for Jupyter notebooks."""
+ return self.to_svg()
+
+ def _repr_svg_(self) -> str:
+ """SVG representation for Jupyter notebooks."""
+ svg = self.to_svg()
+ # Only return if it's actual SVG content
+ if svg.strip().startswith("