5#if defined(_WIN32) || defined(_WIN64)
9#include <Lightweight/Lightweight.hpp>
11#include <catch2/catch_session.hpp>
12#include <catch2/catch_test_macros.hpp>
26#if __has_include(<stacktrace>)
36using WideChar = std::conditional_t<
sizeof(wchar_t) == 2,
wchar_t,
char16_t>;
37using WideString = std::basic_string<WideChar>;
38using WideStringView = std::basic_string_view<WideChar>;
41 #define WTEXT(x) (u##x)
43 #define WTEXT(x) (L##x)
46#define UNSUPPORTED_DATABASE(stmt, dbType) \
47 if ((stmt).Connection().ServerType() == (dbType)) \
49 WARN(std::format("TODO({}): This database is currently unsupported on this test.", dbType)); \
54struct std::formatter<std::u8string>: std::formatter<std::string>
56 auto format(std::u8string
const& value, std::format_context& ctx)
const -> std::format_context::iterator
58 return std::formatter<std::string>::format(
59 std::format(
"{}", (
char const*) value.c_str()),
70template <
typename W
ideStringT>
71 requires(detail::OneOf<WideStringT,
78ostream&
operator<<(ostream& os, WideStringT
const& str)
80 auto constexpr BitsPerChar =
sizeof(
typename WideStringT::value_type) * 8;
81 auto const u8String =
ToUtf8(str);
82 return os <<
"UTF-" << BitsPerChar <<
'{' <<
"length: " << str.size() <<
", characters: " <<
'"'
83 << string_view((
char const*) u8String.data(), u8String.size()) <<
'"' <<
'}';
/// Writes a GUID to a stream as `SqlGuid(<guid>)`; relies on a std::formatter being available for SqlGuid.
inline ostream& operator<<(ostream& os, SqlGuid const& guid)
{
    auto const rendered = format("SqlGuid({})", guid);
    os << rendered;
    return os;
}
93template <std::
size_t Precision, std::
size_t Scale>
96 return os << std::format(
"SqlNumeric<{}, {}>({}, {}, {}, {})",
112 .value = std::format(
"DRIVER={};Database={}",
113#
if defined(_WIN32) || defined(_WIN64)
114 "SQLite3 ODBC Driver",
121class TestSuiteSqlLogger:
public SqlLogger::Null
124 std::string m_lastPreparedQuery;
126 template <
typename... Args>
127 void WriteInfo(std::format_string<Args...>
const& fmt, Args&&... args)
129 auto message = std::format(fmt, std::forward<Args>(args)...);
130 message = std::format(
"[{}] {}",
"Lightweight", message);
133 UNSCOPED_INFO(message);
137 std::println(
"{}", message);
/// Formats the message with std::format and reports it through Catch2's WARN macro.
template <typename... Args>
void WriteWarning(std::format_string<Args...> const& fmt, Args&&... args)
{
    auto const warningText = std::format(fmt, std::forward<Args>(args)...);
    WARN(warningText);
}
148 static TestSuiteSqlLogger& GetLogger() noexcept
150 static TestSuiteSqlLogger theLogger;
154 void OnError(SqlError error, std::source_location sourceLocation)
override
156 WriteWarning(
"SQL Error: {}", error);
157 WriteDetails(sourceLocation);
160 void OnError(
SqlErrorInfo const& errorInfo, std::source_location sourceLocation)
override
162 WriteWarning(
"SQL Error: {}", errorInfo);
163 WriteDetails(sourceLocation);
166 void OnWarning(std::string_view
const& message)
override
168 WriteWarning(
"{}", message);
169 WriteDetails(std::source_location::current());
172 void OnExecuteDirect(std::string_view
const& query)
override
174 WriteInfo(
"ExecuteDirect: {}", query);
177 void OnPrepare(std::string_view
const& query)
override
179 m_lastPreparedQuery = query;
182 void OnExecute(std::string_view
const& query)
override
184 WriteInfo(
"Execute: {}", query);
187 void OnExecuteBatch()
override
189 WriteInfo(
"ExecuteBatch: {}", m_lastPreparedQuery);
192 void OnFetchRow()
override
194 WriteInfo(
"Fetched row");
197 void OnFetchEnd()
override
199 WriteInfo(
"Fetch end");
203 void WriteDetails(std::source_location sourceLocation)
205 WriteInfo(
" Source: {}:{}", sourceLocation.file_name(), sourceLocation.line());
206 if (!m_lastPreparedQuery.empty())
207 WriteInfo(
" Query: {}", m_lastPreparedQuery);
208 WriteInfo(
" Stack trace:");
210#if __has_include(<stacktrace>)
211 auto stackTrace = std::stacktrace::current(1, 25);
212 for (std::size_t
const i: std::views::iota(std::size_t(0), stackTrace.size()))
213 WriteInfo(
" [{:>2}] {}", i, stackTrace[i]);
219class ScopedSqlNullLogger:
public SqlLogger::Null
225 ScopedSqlNullLogger()
230 ~ScopedSqlNullLogger()
override
236template <
typename Getter,
typename Callable>
237constexpr void FixedPointIterate(Getter
const& getter, Callable
const& callable)
254 static inline std::string testDatabaseName =
"LightweightTest";
255 static inline bool odbcTrace =
false;
257 using MainProgramArgs = std::tuple<int, char**>;
259 static std::variant<MainProgramArgs, int> Initialize(
int argc,
char** argv)
263 using namespace std::string_view_literals;
265 for (; i < argc; ++i)
267 if (argv[i] ==
"--trace-sql"sv)
269 else if (argv[i] ==
"--trace-odbc"sv)
271 else if (argv[i] ==
"--help"sv || argv[i] ==
"-h"sv)
273 std::println(
"{} [--trace-sql] [--trace-odbc] [[--] [Catch2 flags ...]]", argv[0]);
274 return { EXIT_SUCCESS };
276 else if (argv[i] ==
"--"sv)
286 argv[i - 1] = argv[0];
289 char* envBuffer =
nullptr;
290 size_t envBufferLen = 0;
291 _dupenv_s(&envBuffer, &envBufferLen,
"ODBC_CONNECTION_STRING");
292 if (
auto const* s = envBuffer; s && *s)
294 if (
auto const* s = std::getenv(
"ODBC_CONNECTION_STRING"); s && *s)
298 std::println(
"Using ODBC connection string: '{}'", SqlConnectionString::SanitizePwd(s));
304 std::println(
"Using default ODBC connection string: '{}'", DefaultTestConnectionString.value);
311 if (!sqlConnection.IsAlive())
313 std::println(
"Failed to connect to the database: {}", sqlConnection.LastError());
317 std::println(
"Running test cases against: {} ({}) (identified as: {})",
318 sqlConnection.ServerName(),
319 sqlConnection.ServerVersion(),
320 sqlConnection.ServerType());
322 return MainProgramArgs { argc - (i - 1), argv + (i - 1) };
329 auto const traceFile = []() -> std::string_view {
330#if !defined(_WIN32) && !defined(_WIN64)
331 return "/dev/stdout";
338 SQLSetConnectAttrA(handle, SQL_ATTR_TRACEFILE, (SQLPOINTER) traceFile.data(), SQL_NTS);
339 SQLSetConnectAttrA(handle, SQL_ATTR_TRACE, (SQLPOINTER) SQL_OPT_TRACE_ON, SQL_IS_UINTEGER);
344 case SqlServerType::SQLITE: {
347 stmt.ExecuteDirect(
"PRAGMA foreign_keys = ON");
350 case SqlServerType::MICROSOFT_SQL:
351 case SqlServerType::POSTGRESQL:
352 case SqlServerType::ORACLE:
353 case SqlServerType::MYSQL:
354 case SqlServerType::UNKNOWN:
362 REQUIRE(stmt.IsAlive());
366 SQLSMALLINT dbNameLen {};
367 SQLGetInfo(stmt.Connection().NativeHandle(), SQL_DATABASE_NAME, dbName,
sizeof(dbName), &dbNameLen);
369 testDatabaseName = dbName;
370 else if (stmt.Connection().ServerType() == SqlServerType::ORACLE)
371 testDatabaseName =
"FREEPDB1";
373 DropAllTablesInDatabase(stmt);
376 virtual ~SqlTestFixture() =
default;
378 static std::string ToString(std::vector<std::string>
const& values, std::string_view separator)
380 auto result = std::string {};
381 for (
auto const& value: values)
/// Drops @p table after first (depth-first) dropping every table that references it
/// through a foreign key, so the final DROP is not rejected by FK constraints.
///
/// @param stmt  statement used for schema queries and the DROP statements.
/// @param table table to drop; only its unqualified `table` name is used in the DROP.
static void DropTableRecursively(SqlStatement& stmt, SqlSchema::FullyQualifiedTableName const& table)
{
    auto const dependantTables = SqlSchema::AllForeignKeysTo(stmt, table);
    for (auto const& dependantTable: dependantTables)
    {
        // A self-referencing foreign key (e.g. a parent-ID column) reports the table itself
        // as a dependant; recursing into it would never terminate. Dropping the table removes
        // that constraint anyway. Only the unqualified name is compared, consistent with the
        // DROP statement below, which also ignores catalog/schema.
        // NOTE(review): mutual cycles between two different tables are still not handled.
        if (dependantTable.foreignKey.table.table == table.table)
            continue;
        DropTableRecursively(stmt, dependantTable.foreignKey.table);
    }
    stmt.ExecuteDirect(std::format("DROP TABLE IF EXISTS \"{}\"", table.table));
}
402 case SqlServerType::MICROSOFT_SQL:
403 case SqlServerType::MYSQL:
404 stmt.
ExecuteDirect(std::format(
"USE \"{}\"", testDatabaseName));
406 case SqlServerType::SQLITE:
407 case SqlServerType::ORACLE:
408 case SqlServerType::UNKNOWN: {
409 auto const tableNames = GetAllTableNames(stmt);
410 for (
auto const& tableName: tableNames)
411 DropTableRecursively(stmt,
412 SqlSchema::FullyQualifiedTableName {
419 case SqlServerType::POSTGRESQL:
420 if (m_createdTables.empty())
421 m_createdTables = GetAllTableNames(stmt);
422 for (
auto& createdTable: std::views::reverse(m_createdTables))
423 stmt.ExecuteDirect(std::format(
"DROP TABLE IF EXISTS \"{}\" CASCADE", createdTable));
426 m_createdTables.clear();
430 static std::vector<std::string> GetAllTableNamesForOracle(
SqlStatement& stmt)
432 auto result = std::vector<std::string> {};
433 stmt.
Prepare(R
"SQL(SELECT table_name
435 WHERE table_name NOT LIKE '%$%'
436 AND table_name NOT IN ('SCHEDULER_JOB_ARGS_TBL', 'SCHEDULER_PROGRAM_ARGS_TBL', 'SQLPLUS_PRODUCT_PROFILE')
437 ORDER BY table_name)SQL");
441 result.emplace_back(stmt.
GetColumn<std::string>(1));
446 static std::vector<std::string> GetAllTableNames(
SqlStatement& stmt)
449 return GetAllTableNamesForOracle(stmt);
451 using namespace std::string_literals;
452 auto result = std::vector<std::string>();
453 auto const schemaName = [&] {
456 case SqlServerType::MICROSOFT_SQL:
463 (SQLCHAR*) testDatabaseName.data(),
464 (SQLSMALLINT) testDatabaseName.size(),
465 (SQLCHAR*) schemaName.data(),
466 (SQLSMALLINT) schemaName.size(),
471 if (SQL_SUCCEEDED(sqlResult))
475 result.emplace_back(stmt.
GetColumn<std::string>(3));
481 static inline std::vector<std::string> m_createdTables;
/// Writes a SqlText to a stream as `SqlText(<text>)`, using its `value` member.
inline std::ostream& operator<<(std::ostream& os, SqlText const& value)
{
    auto const& text = value.value;
    auto const rendered = std::format("SqlText({})", text);
    os << rendered;
    return os;
}
/// Writes a SqlDate to a stream as `SqlDate { YYYY-MM-DD }`, with a numeric,
/// zero-padded month and day (the same numeric layout the SqlDateTime printer uses).
inline std::ostream& operator<<(std::ostream& os, SqlDate const& date)
{
    auto const ymd = date.value();
    // Convert to plain integers first: std::format on std::chrono::month would print the
    // abbreviated month name ("Jan") rather than its number.
    return os << std::format("SqlDate {{ {:04}-{:02}-{:02} }}",
                             static_cast<int>(ymd.year()),
                             static_cast<unsigned>(ymd.month()),
                             static_cast<unsigned>(ymd.day()));
}
/// Writes a SqlTime to a stream as `SqlTime { HH:MM:SS.ffffff }`, each field zero-padded.
/// NOTE(review): the fractional field is padded to 6 digits whatever the precision of
/// `subseconds()` is; presumably microseconds — confirm.
inline std::ostream& operator<<(std::ostream& os, SqlTime const& time)
{
    auto const hms = time.value();
    auto const hours = hms.hours().count();
    auto const minutes = hms.minutes().count();
    auto const seconds = hms.seconds().count();
    auto const fraction = hms.subseconds().count();
    os << std::format("SqlTime {{ {:02}:{:02}:{:02}.{:06} }}", hours, minutes, seconds, fraction);
    return os;
}
506inline std::ostream& operator<<(std::ostream& os,
SqlDateTime const& datetime)
508 auto const value = datetime.
value();
509 auto const totalDays = std::chrono::floor<std::chrono::days>(value);
510 auto const ymd = std::chrono::year_month_day { totalDays };
512 std::chrono::hh_mm_ss<std::chrono::nanoseconds> { std::chrono::floor<std::chrono::nanoseconds>(value - totalDays) };
513 return os << std::format(
"SqlDateTime {{ {:04}-{:02}-{:02} {:02}:{:02}:{:02}.{:09} }}",
515 (
unsigned) ymd.month(),
516 (
unsigned) ymd.day(),
518 hms.minutes().count(),
519 hms.seconds().count(),
520 hms.subseconds().count());
523template <std::
size_t N,
typename T, SqlFixedStringMode Mode>
526 if constexpr (Mode == SqlFixedStringMode::FIXED_SIZE)
527 return os << std::format(
"SqlFixedString<{}> {{ size: {}, data: '{}' }}", N, value.
size(), value.data());
528 else if constexpr (Mode == SqlFixedStringMode::FIXED_SIZE_RIGHT_TRIMMED)
529 return os << std::format(
"SqlTrimmedFixedString<{}> {{ '{}' }}", N, value.data());
530 else if constexpr (Mode == SqlFixedStringMode::VARIABLE_SIZE)
532 if constexpr (std::same_as<T, char>)
533 return os << std::format(
"SqlVariableString<{}> {{ size: {}, '{}' }}", N, value.
size(), value.data());
536 auto u8String =
ToUtf8(std::basic_string_view<T>(value.data(), value.
size()));
537 return os << std::format(
"SqlVariableString<{}, {}> {{ size: {}, '{}' }}",
539 Reflection::TypeNameOf<T>,
541 (
char const*) u8String.c_str());
545 return os << std::format(
"SqlFixedString<{}> {{ size: {}, data: '{}' }}", N, value.
size(), value.data());
548template <std::
size_t N,
typename T>
551 if constexpr (std::same_as<T, char>)
552 return os << std::format(
"SqlDynamicString<{}> {{ size: {}, '{}' }}", N, value.
size(), value.
data());
555 auto u8String =
ToUtf8(std::basic_string_view<T>(value.
data(), value.
size()));
556 return os << std::format(
"SqlDynamicString<{}, {}> {{ size: {}, '{}' }}",
558 Reflection::TypeNameOf<T>,
560 (
char const*) u8String.c_str());
564[[nodiscard]]
inline std::string NormalizeText(std::string_view
const& text)
566 auto result = std::string(text);
569 result.erase(std::unique(result.begin(),
571 [](
char a,
char b) { return std::isspace(a) && std::isspace(b); }),
575 while (!result.empty() && std::isspace(result.front()))
576 result.erase(result.begin());
578 while (!result.empty() && std::isspace(result.back()))
584[[nodiscard]]
inline std::string NormalizeText(std::vector<std::string>
const& texts)
586 auto result = std::string {};
587 for (
auto const& text: texts)
591 result += NormalizeText(text);
598inline void CreateEmployeesTable(
SqlStatement& stmt, std::source_location location = std::source_location::current())
603 .PrimaryKeyWithAutoIncrement(
"EmployeeID")
604 .RequiredColumn(
"FirstName", SqlColumnTypeDefinitions::Varchar { 50 })
605 .Column(
"LastName", SqlColumnTypeDefinitions::Varchar { 50 })
606 .RequiredColumn(
"Salary", SqlColumnTypeDefinitions::Integer {});
615 for (
char c =
'A'; c <=
'Z'; ++c)
617 table.
Column(std::string(1, c), SqlColumnTypeDefinitions::Varchar { 50 });
626 .Set(
"FirstName", SqlWildcard)
627 .Set(
"LastName", SqlWildcard)
628 .Set(
"Salary", SqlWildcard));
629 stmt.
Execute(
"Alice",
"Smith", 50'000);
630 stmt.
Execute(
"Bob",
"Johnson", 60'000);
631 stmt.
Execute(
"Charlie",
"Brown", 70'000);
634template <
typename T =
char>
635inline auto MakeLargeText(
size_t size)
637 auto text = std::basic_string<T>(size, {});
638 std::ranges::generate(text, [i = 0]()
mutable {
return static_cast<T
>(
'A' + (i++ % 26)); });
/// Tells whether the process is running inside a GitHub Actions job, i.e. whether the
/// GITHUB_ACTIONS environment variable is set to exactly "true".
inline bool IsGithubActions()
{
#if defined(_WIN32) || defined(_WIN64)
    // getenv_s() presumably used to avoid MSVC's deprecation warning on std::getenv().
    // It also returns 0 when the variable is not set (leaving the buffer empty), so the
    // value comparison is what actually decides the result.
    char envBuffer[32] {};
    size_t requiredCount = 0;
    if (getenv_s(&requiredCount, envBuffer, sizeof(envBuffer), "GITHUB_ACTIONS") != 0)
        return false;
    // Previously `== "true" == 0`, which parsed as `(... == "true") == 0` and inverted the result.
    return std::string_view(envBuffer) == "true";
#else
    // Query the environment once instead of twice.
    auto const* const value = std::getenv("GITHUB_ACTIONS");
    return value != nullptr && std::string_view(value) == "true";
#endif
}
Represents a connection to a SQL database.
SqlServerType ServerType() const noexcept
Retrieves the type of the server.
static LIGHTWEIGHT_API void SetDefaultConnectionString(SqlConnectionString const &connectionString) noexcept
SQLHDBC NativeHandle() const noexcept
Retrieves the native handle.
static LIGHTWEIGHT_API void SetPostConnectedHook(std::function< void(SqlConnection &)> hook)
Sets a callback to be called after each connection being established.
LIGHTWEIGHT_API SqlCreateTableQueryBuilder & Column(SqlColumnDeclaration column)
Adds a new column to the table.
LIGHTWEIGHT_FORCE_INLINE T const * data() const noexcept
Retrieves the string's inner value (as T const*).
LIGHTWEIGHT_FORCE_INLINE std::size_t size() const noexcept
Retrieves the string's size.
LIGHTWEIGHT_FORCE_INLINE constexpr std::size_t size() const noexcept
Returns the size of the string.
Represents a logger for SQL operations.
static LIGHTWEIGHT_API void SetLogger(SqlLogger &logger)
static LIGHTWEIGHT_API SqlLogger & GetLogger()
Retrieves the currently configured logger.
static LIGHTWEIGHT_API SqlLogger & TraceLogger()
Retrieves a logger that logs to the trace logger.
Query builder for building SQL migration queries.
LIGHTWEIGHT_API SqlCreateTableQueryBuilder CreateTable(std::string_view tableName)
Creates a new table.
LIGHTWEIGHT_API SqlInsertQueryBuilder Insert(std::vector< SqlVariant > *boundInputs=nullptr) noexcept
High level API for (prepared) raw SQL statements.
LIGHTWEIGHT_API SqlConnection & Connection() noexcept
Retrieves the connection associated with this statement.
LIGHTWEIGHT_API SQLHSTMT NativeHandle() const noexcept
Retrieves the native handle of the statement.
void Execute(Args const &... args)
Binds the given arguments to the prepared statement and executes it.
LIGHTWEIGHT_API void ExecuteDirect(std::string_view const &query, std::source_location location=std::source_location::current())
Executes the given query directly.
void MigrateDirect(Callable const &callable, std::source_location location=std::source_location::current())
Executes an SQL migration query, as created by the callback.
LIGHTWEIGHT_API bool FetchRow()
LIGHTWEIGHT_API void Prepare(std::string_view query) &
bool GetColumn(SQLUSMALLINT column, T *result) const
LIGHTWEIGHT_API SqlQueryBuilder Query(std::string_view const &table={}) const
Creates a new query builder for the given table, compatible with the SQL server being connected.
LIGHTWEIGHT_API std::u8string ToUtf8(std::u32string_view u32InputString)
Represents an ODBC connection string.
constexpr LIGHTWEIGHT_FORCE_INLINE native_type value() const noexcept
Returns the stored date and time value.
LIGHTWEIGHT_FORCE_INLINE constexpr std::chrono::year_month_day value() const noexcept
Returns the stored date value.
Represents an ODBC SQL error.
constexpr LIGHTWEIGHT_FORCE_INLINE auto ToUnscaledValue() const noexcept
Converts the numeric to an unscaled integer value.
SQL_NUMERIC_STRUCT sqlValue
The value is stored as an ODBC SQL_NUMERIC_STRUCT (an unscaled integer plus precision and scale) to avoid floating point precision issues.