diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc index 01780f0efe2..4868145370c 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc @@ -890,8 +890,13 @@ SQLRETURN SQLNumResultCols(SQLHSTMT stmt, SQLSMALLINT* column_count_ptr) { ARROW_LOG(DEBUG) << "SQLNumResultCols called with stmt: " << stmt << ", column_count_ptr: " << static_cast(column_count_ptr); - // GH-47713 TODO: Implement SQLNumResultCols - return SQL_INVALID_HANDLE; + + using ODBC::ODBCStatement; + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + statement->GetColumnCount(column_count_ptr); + return SQL_SUCCESS; + }); } SQLRETURN SQLRowCount(SQLHSTMT stmt, SQLLEN* row_count_ptr) { diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.cc index d452e77db1d..f42a455fdc2 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.cc @@ -735,6 +735,16 @@ bool ODBCStatement::GetData(SQLSMALLINT record_number, SQLSMALLINT c_type, data_ptr, buffer_length, indicator_ptr); } +void ODBCStatement::GetColumnCount(SQLSMALLINT* column_count_ptr) { + if (!column_count_ptr) { + // column count pointer is not valid, do nothing as ODBC spec does not mention this as + // an error + return; + } + size_t column_count = ird_->GetRecords().size(); + *column_count_ptr = static_cast(column_count); +} + void ODBCStatement::ReleaseStatement() { CloseCursor(true); connection_.DropStatement(this); diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.h index 8e128db1bda..192028578d5 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_statement.h @@ -80,6 +80,9 @@ class ODBCStatement : public ODBCHandle { bool GetData(SQLSMALLINT record_number, SQLSMALLINT c_type, SQLPOINTER data_ptr, SQLLEN buffer_length, SQLLEN* indicator_ptr); + /// \brief Return number of columns from data set + void GetColumnCount(SQLSMALLINT* column_count_ptr); + /** * @brief Closes the cursor. This does _not_ un-prepare the statement or change * bindings. diff --git a/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt b/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt index 4bc240637e7..cf3e15451d9 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt @@ -35,6 +35,7 @@ add_arrow_test(flight_sql_odbc_test odbc_test_suite.cc odbc_test_suite.h connection_test.cc + statement_test.cc # Enable Protobuf cleanup after test execution # GH-46889: move protobuf_test_util to a more common location ../../../../engine/substrait/protobuf_test_util.cc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc new file mode 100644 index 00000000000..9172d471096 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +#include "arrow/flight/sql/odbc/odbc_impl/platform.h" + +#include +#include +#include + +#include + +#include +#include + +namespace arrow::flight::sql::odbc { + +template +class StatementTest : public T {}; + +class StatementMockTest : public FlightSQLODBCMockTestBase {}; +class StatementRemoteTest : public FlightSQLODBCRemoteTestBase {}; +using TestTypes = ::testing::Types; +TYPED_TEST_SUITE(StatementTest, TestTypes); + +TYPED_TEST(StatementTest, SQLNumResultColsReturnsColumnsOnSelect) { + SQLSMALLINT column_count = 0; + SQLSMALLINT expected_value = 3; + SQLWCHAR sql_query[] = L"SELECT 1 AS col1, 'One' AS col2, 3 AS col3"; + SQLINTEGER query_length = static_cast(wcslen(sql_query)); + + ASSERT_EQ(SQL_SUCCESS, SQLExecDirect(this->stmt, sql_query, query_length)); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckIntColumn(this->stmt, 1, 1); + CheckStringColumnW(this->stmt, 2, L"One"); + CheckIntColumn(this->stmt, 3, 3); + + ASSERT_EQ(SQL_SUCCESS, SQLNumResultCols(this->stmt, &column_count)); + + EXPECT_EQ(expected_value, column_count); +} + +TYPED_TEST(StatementTest, SQLNumResultColsReturnsSuccessOnNullptr) { + SQLWCHAR sql_query[] = L"SELECT 1 AS col1, 'One' AS col2, 3 AS col3"; + SQLINTEGER query_length = static_cast(wcslen(sql_query)); + + ASSERT_EQ(SQL_SUCCESS, SQLExecDirect(this->stmt, sql_query, query_length)); + + ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt)); + + CheckIntColumn(this->stmt, 1, 1); + CheckStringColumnW(this->stmt, 2, L"One"); + CheckIntColumn(this->stmt, 3, 3); + + ASSERT_EQ(SQL_SUCCESS, SQLNumResultCols(this->stmt, nullptr)); +} + +TYPED_TEST(StatementTest, SQLNumResultColsFunctionSequenceErrorOnNoQuery) { + SQLSMALLINT column_count = 0; + SQLSMALLINT expected_value = 0; + + ASSERT_EQ(SQL_ERROR, SQLNumResultCols(this->stmt, &column_count)); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY010); + + EXPECT_EQ(expected_value, column_count); +} + +} // namespace arrow::flight::sql::odbc