5#if defined(_WIN32) || defined(_WIN64)
9#include <Lightweight/Lightweight.hpp>
11#include <catch2/catch_session.hpp>
12#include <catch2/catch_test_macros.hpp>
27#include <yaml-cpp/yaml.h>
29#if __has_include(<stacktrace>)
// 16-bit wide character type used by the tests: wchar_t on platforms where it is
// two bytes wide, char16_t everywhere else.
using WideChar = std::conditional<sizeof(wchar_t) == 2, wchar_t, char16_t>::type;

// Owning string of WideChar characters.
using WideString = std::basic_string<WideChar>;

// Non-owning view over WideChar characters.
using WideStringView = std::basic_string_view<WideChar>;
43#if defined(LIGHTWEIGHT_CXX26_REFLECTION)
55 #define WTEXT(x) (u##x)
57 #define WTEXT(x) (L##x)
60#define UNSUPPORTED_DATABASE(stmt, dbType) \
61 if ((stmt).Connection().ServerType() == (dbType)) \
63 WARN(std::format("TODO({}): This database is currently unsupported on this test.", dbType)); \
68struct std::formatter<std::u8string>: std::formatter<std::string>
70 auto format(std::u8string
const& value, std::format_context& ctx)
const -> std::format_context::iterator
72 return std::formatter<std::string>::format(
73 std::format(
"{}", (
char const*) value.c_str()),
84template <
typename W
ideStringT>
85 requires(Lightweight::detail::OneOf<WideStringT,
92ostream&
operator<<(ostream& os, WideStringT
const& str)
94 auto constexpr BitsPerChar =
sizeof(
typename WideStringT::value_type) * 8;
96 return os <<
"UTF-" << BitsPerChar <<
'{' <<
"length: " << str.size() <<
", characters: " <<
'"'
97 << string_view((
char const*) u8String.data(), u8String.size()) <<
'"' <<
'}';
102 return os << format(
"SqlGuid({})", guid);
105inline std::string EscapedBinaryText(std::string_view binary)
107 std::string hexEncodedString;
108 for (
auto const& b: binary)
111 hexEncodedString +=
static_cast<char>(b);
113 hexEncodedString += std::format(
"\\x{:02x}",
static_cast<unsigned char>(b));
115 return hexEncodedString;
121 auto const hexEncodedString = EscapedBinaryText(std::string_view((
char const*) binary.
data(), binary.
size()));
122 return os << std::format(
"SqlDynamicBinary<{}>(length: {}, characters: {})", N, binary.
size(), hexEncodedString);
129template <std::
size_t Precision, std::
size_t Scale>
132 return os << std::format(
"SqlNumeric<{}, {}>({}, {}, {}, {})",
136 value.sqlValue.precision,
137 value.sqlValue.scale,
152 .value = std::format(
"DRIVER={};Database={}",
153#
if defined(_WIN32) || defined(_WIN64)
154 "SQLite3 ODBC Driver",
161class TestSuiteSqlLogger:
public Lightweight::SqlLogger::Null
164 mutable std::mutex m_mutex;
165 std::string m_lastPreparedQuery;
167 void WriteRawInfo(std::string_view message);
169 template <
typename... Args>
170 void WriteInfo(std::format_string<Args...>
const& fmt, Args&&... args)
172 auto message = std::format(fmt, std::forward<Args>(args)...);
173 message = std::format(
"[{}] {}",
"Lightweight", message);
174 WriteRawInfo(message);
177 template <
typename... Args>
178 void WriteWarning(std::format_string<Args...>
const& fmt, Args&&... args)
180 WARN(std::format(fmt, std::forward<Args>(args)...));
184 static TestSuiteSqlLogger& GetLogger() noexcept
186 static TestSuiteSqlLogger theLogger;
190 void OnError(Lightweight::SqlError error, std::source_location sourceLocation)
override
192 WriteWarning(
"SQL Error: {}", error);
193 WriteDetails(sourceLocation);
198 WriteWarning(
"SQL Error: {}", errorInfo);
199 WriteDetails(sourceLocation);
202 void OnWarning(std::string_view
const& message)
override
204 WriteWarning(
"{}", message);
205 WriteDetails(std::source_location::current());
208 void OnExecuteDirect(std::string_view
const& query)
override
210 WriteInfo(
"ExecuteDirect: {}", query);
213 void OnPrepare(std::string_view
const& query)
override
215 std::scoped_lock lock(m_mutex);
216 m_lastPreparedQuery = query;
219 void OnExecute(std::string_view
const& query)
override
221 WriteInfo(
"Execute: {}", query);
224 void OnExecuteBatch()
override
226 std::scoped_lock lock(m_mutex);
227 WriteInfo(
"ExecuteBatch: {}", m_lastPreparedQuery);
230 void OnFetchRow()
override
232 WriteInfo(
"Fetched row");
235 void OnFetchEnd()
override
241 void WriteDetails(std::source_location sourceLocation)
243 std::scoped_lock lock(m_mutex);
244 WriteInfo(
" Source: {}:{}", sourceLocation.file_name(), sourceLocation.line());
245 if (!m_lastPreparedQuery.empty())
246 WriteInfo(
" Query: {}", m_lastPreparedQuery);
247 WriteInfo(
" Stack trace:");
249#if __has_include(<stacktrace>)
250 auto stackTrace = std::stacktrace::current(1, 25);
251 for (
auto const& entry: stackTrace)
252 WriteInfo(
" {}", entry);
258class ScopedSqlNullLogger:
public Lightweight::SqlLogger::Null
261 SqlLogger& m_previousLogger = SqlLogger::GetLogger();
264 ScopedSqlNullLogger()
266 SqlLogger::SetLogger(*
this);
269 ~ScopedSqlNullLogger()
override
271 SqlLogger::SetLogger(m_previousLogger);
275template <
typename Getter,
typename Callable>
276constexpr void FixedPointIterate(Getter
const& getter, Callable
const& callable)
292inline std::optional<std::filesystem::path> FindTestEnvFile()
294 auto currentDir = std::filesystem::current_path();
295 std::filesystem::path projectRoot;
298 auto searchDir = currentDir;
301 auto localEnvPath = searchDir /
".test-env.local.yml";
302 if (std::filesystem::exists(localEnvPath))
305 auto gitPath = searchDir /
".git";
306 if (std::filesystem::exists(gitPath))
308 projectRoot = searchDir;
312 auto parentDir = searchDir.parent_path();
313 if (parentDir == searchDir)
315 searchDir = parentDir;
319 if (!projectRoot.empty())
321 auto scriptsLocalEnvPath = projectRoot /
"scripts" /
"tests" /
".test-env.local.yml";
322 if (std::filesystem::exists(scriptsLocalEnvPath))
323 return scriptsLocalEnvPath;
327 searchDir = currentDir;
330 auto testEnvPath = searchDir /
".test-env.yml";
331 if (std::filesystem::exists(testEnvPath))
334 auto gitPath = searchDir /
".git";
335 if (std::filesystem::exists(gitPath))
337 projectRoot = searchDir;
341 auto parentDir = searchDir.parent_path();
342 if (parentDir == searchDir)
344 searchDir = parentDir;
348 auto scriptsTestEnvPath = projectRoot /
"scripts" /
"tests" /
".test-env.yml";
349 if (std::filesystem::exists(scriptsTestEnvPath))
350 return scriptsTestEnvPath;
359 static inline std::string testDatabaseName =
"LightweightTest";
360 static inline bool odbcTrace =
false;
361 static inline std::atomic<bool> running =
false;
363 using MainProgramArgs = std::tuple<int, char**>;
366 static std::variant<MainProgramArgs, int> Initialize(
int argc,
char** argv)
370 using namespace std::string_view_literals;
371 std::optional<std::string> testEnvName;
373 for (; i < argc; ++i)
375 if (argv[i] ==
"--trace-sql"sv)
377 else if (argv[i] ==
"--trace-odbc"sv)
379 else if (std::string_view(argv[i]).starts_with(
"--test-env="))
380 testEnvName = std::string_view(argv[i]).substr(11);
381 else if (argv[i] ==
"--help"sv || argv[i] ==
"-h"sv)
383 std::println(
"{} [--test-env=NAME] [--trace-sql] [--trace-odbc] [[--] [Catch2 flags ...]]", argv[0]);
385 std::println(
"Options:");
386 std::println(
" --test-env=NAME Use connection string from .test-env.yml (e.g., pgsql, mssql, sqlite)");
387 std::println(
" --trace-sql Enable SQL tracing");
388 std::println(
" --trace-odbc Enable ODBC tracing");
389 return { EXIT_SUCCESS };
391 else if (argv[i] ==
"--"sv)
401 argv[i - 1] = argv[0];
410 auto configPath = FindTestEnvFile();
414 "Error: .test-env.yml not found (searched from '{}' to project root)",
415 std::filesystem::current_path().
string());
416 return { EXIT_FAILURE };
421 YAML::Node config = YAML::LoadFile(configPath->string());
422 auto connectionStrings = config[
"ODBC_CONNECTION_STRING"];
423 if (!connectionStrings || !connectionStrings[*testEnvName])
425 std::println(stderr,
"Error: Key '{}' not found in ODBC_CONNECTION_STRING", *testEnvName);
426 if (connectionStrings && connectionStrings.IsMap())
428 std::print(stderr,
"Available environments:");
429 for (
auto const& entry: connectionStrings)
430 std::print(stderr,
" {}", entry.first.as<std::string>());
431 std::println(stderr,
"");
433 return { EXIT_FAILURE };
435 auto connStr = connectionStrings[*testEnvName].as<std::string>();
438 std::println(stderr,
"Error: Connection string for '{}' is empty", *testEnvName);
439 return { EXIT_FAILURE };
441 std::println(
"Using test environment '{}' from: {}", *testEnvName, configPath->string());
442 std::println(
"Using ODBC connection string: '{}'", Lightweight::SqlConnectionString::SanitizePwd(connStr));
445 catch (YAML::Exception
const& e)
447 std::println(stderr,
"Error parsing {}: {}", configPath->string(), e.what());
448 return { EXIT_FAILURE };
454 char* envBuffer =
nullptr;
455 size_t envBufferLen = 0;
456 _dupenv_s(&envBuffer, &envBufferLen,
"ODBC_CONNECTION_STRING");
457 if (
auto const* s = envBuffer; s && *s)
459 if (
auto const* s = std::getenv(
"ODBC_CONNECTION_STRING"); s && *s)
463 std::println(
"Using ODBC connection string: '{}'", Lightweight::SqlConnectionString::SanitizePwd(s));
469 std::println(
"Using default ODBC connection string: '{}'", DefaultTestConnectionString.value);
477 if (!sqlConnection.IsAlive())
479 std::println(
"Failed to connect to the database: {}", sqlConnection.LastError());
483 std::println(
"Running test cases against: {} ({}) (identified as: {})",
484 sqlConnection.ServerName(),
485 sqlConnection.ServerVersion(),
486 sqlConnection.ServerType());
488 return MainProgramArgs { argc - (i - 1), argv + (i - 1) };
495 auto const traceFile = []() -> std::string_view {
496#if !defined(_WIN32) && !defined(_WIN64)
497 return "/dev/stdout";
504 SQLSetConnectAttrA(handle, SQL_ATTR_TRACEFILE, (SQLPOINTER) traceFile.data(), SQL_NTS);
505 SQLSetConnectAttrA(handle, SQL_ATTR_TRACE, (SQLPOINTER) SQL_OPT_TRACE_ON, SQL_IS_UINTEGER);
508 using Lightweight::SqlServerType;
511 case SqlServerType::SQLITE: {
514 stmt.ExecuteDirect(
"PRAGMA foreign_keys = ON");
517 case SqlServerType::MICROSOFT_SQL:
518 case SqlServerType::POSTGRESQL:
519 case SqlServerType::MYSQL:
520 case SqlServerType::UNKNOWN:
529 REQUIRE(stmt.IsAlive());
532 SQLSMALLINT dbNameLen {};
533 SQLGetInfo(stmt.Connection().NativeHandle(), SQL_DATABASE_NAME, dbName,
sizeof(dbName), &dbNameLen);
535 testDatabaseName = dbName;
537 DropAllTablesInDatabase(stmt);
540 virtual ~SqlTestFixture()
546 static std::string ToString(std::vector<std::string>
const& values, std::string_view separator)
548 auto result = std::string {};
549 for (
auto const& value: values)
559 Lightweight::SqlSchema::FullyQualifiedTableName
const& table)
561 auto const dependantTables = Lightweight::SqlSchema::AllForeignKeysTo(stmt, table);
562 for (
auto const& dependantTable: dependantTables)
563 DropTableRecursively(stmt, dependantTable.foreignKey.table);
564 stmt.
ExecuteDirect(std::format(
"DROP TABLE IF EXISTS \"{}\"", table.table));
572 stmt.
ExecuteDirect(std::format(
"DROP TABLE IF EXISTS {}", tableName));
582 using Lightweight::SqlServerType;
585 case SqlServerType::MICROSOFT_SQL:
586 case SqlServerType::MYSQL:
587 stmt.
ExecuteDirect(std::format(
"USE \"{}\"", testDatabaseName));
589 case SqlServerType::SQLITE:
590 case SqlServerType::UNKNOWN: {
591 auto const tableNames = GetAllTableNames(stmt);
592 for (
auto const& tableName: tableNames)
594 if (tableName ==
"sqlite_sequence")
597 DropTableRecursively(stmt,
598 Lightweight::SqlSchema::FullyQualifiedTableName {
606 case SqlServerType::POSTGRESQL:
607 if (m_createdTables.empty())
608 m_createdTables = GetAllTableNames(stmt);
609 for (
auto& createdTable: std::views::reverse(m_createdTables))
610 stmt.ExecuteDirect(std::format(
"DROP TABLE IF EXISTS \"{}\" CASCADE", createdTable));
613 m_createdTables.clear();
618 using namespace std::string_literals;
621 case Lightweight::SqlServerType::MICROSOFT_SQL:
631 using namespace std::string_literals;
632 auto result = std::vector<std::string>();
633 auto const schemaName = GetDefaultSchemaName(stmt.
Connection());
635 (SQLCHAR*) testDatabaseName.data(),
636 (SQLSMALLINT) testDatabaseName.size(),
637 (SQLCHAR*) schemaName.data(),
638 (SQLSMALLINT) schemaName.size(),
643 if (SQL_SUCCEEDED(sqlResult))
647 result.emplace_back(stmt.
GetColumn<std::string>(3));
653 static inline std::vector<std::string> m_createdTables;
659 return os << std::format(
"SqlText({})", value.value);
664 auto const ymd = date.
value();
665 return os << std::format(
"SqlDate {{ {}-{}-{} }}", ymd.year(), ymd.month(), ymd.day());
670 auto const value = time.value();
671 return os << std::format(
"SqlTime {{ {:02}:{:02}:{:02}.{:06} }}",
672 value.hours().count(),
673 value.minutes().count(),
674 value.seconds().count(),
675 value.subseconds().count());
680 auto const value = datetime.
value();
681 auto const totalDays = std::chrono::floor<std::chrono::days>(value);
682 auto const ymd = std::chrono::year_month_day { totalDays };
684 std::chrono::hh_mm_ss<std::chrono::nanoseconds> { std::chrono::floor<std::chrono::nanoseconds>(value - totalDays) };
685 return os << std::format(
"SqlDateTime {{ {:04}-{:02}-{:02} {:02}:{:02}:{:02}.{:09} }}",
687 (
unsigned) ymd.month(),
688 (
unsigned) ymd.day(),
690 hms.minutes().count(),
691 hms.seconds().count(),
692 hms.subseconds().count());
695template <std::
size_t N,
typename T, Lightweight::SqlFixedStringMode Mode>
698 if constexpr (Mode == Lightweight::SqlFixedStringMode::FIXED_SIZE)
699 return os << std::format(
"SqlFixedString<{}> {{ size: {}, data: '{}' }}", N, value.
size(), value.data());
700 else if constexpr (Mode == Lightweight::SqlFixedStringMode::FIXED_SIZE_RIGHT_TRIMMED)
701 return os << std::format(
"SqlTrimmedFixedString<{}> {{ '{}' }}", N, value.data());
702 else if constexpr (Mode == Lightweight::SqlFixedStringMode::VARIABLE_SIZE)
704 if constexpr (std::same_as<T, char>)
705 return os << std::format(
"SqlVariableString<{}> {{ size: {}, '{}' }}", N, value.
size(), value.data());
708 auto u8String =
ToUtf8(std::basic_string_view<T>(value.data(), value.
size()));
709 return os << std::format(
"SqlVariableString<{}, {}> {{ size: {}, '{}' }}",
711 Reflection::TypeNameOf<T>,
713 (
char const*) u8String.c_str());
717 return os << std::format(
"SqlFixedString<{}> {{ size: {}, data: '{}' }}", N, value.
size(), value.data());
720template <std::
size_t N,
typename T>
723 if constexpr (std::same_as<T, char>)
724 return os << std::format(
"SqlDynamicString<{}> {{ size: {}, '{}' }}", N, value.
size(), value.
data());
727 auto u8String =
ToUtf8(std::basic_string_view<T>(value.
data(), value.
size()));
728 return os << std::format(
"SqlDynamicString<{}, {}> {{ size: {}, '{}' }}",
730 Reflection::TypeNameOf<T>,
732 (
char const*) u8String.c_str());
736[[nodiscard]]
inline std::string NormalizeText(std::string_view
const& text)
738 auto result = std::string(text);
741 result.erase(std::unique(result.begin(),
743 [](
char a,
char b) { return std::isspace(a) && std::isspace(b); }),
747 while (!result.empty() && std::isspace(result.front()))
748 result.erase(result.begin());
750 while (!result.empty() && std::isspace(result.back()))
756[[nodiscard]]
inline std::string NormalizeText(std::vector<std::string>
const& texts)
758 auto result = std::string {};
759 for (
auto const& text: texts)
763 result += NormalizeText(text);
773 if (conn.
ServerType() == Lightweight::SqlServerType::SQLITE)
783 if (conn.
ServerType() == Lightweight::SqlServerType::SQLITE)
792template <
typename Func>
797 stmt.
ExecuteDirect(std::format(
"SET IDENTITY_INSERT \"{}\" ON", tableName));
800 std::forward<Func>(func)();
801 stmt.
ExecuteDirect(std::format(
"SET IDENTITY_INSERT \"{}\" OFF", tableName));
805 stmt.
ExecuteDirect(std::format(
"SET IDENTITY_INSERT \"{}\" OFF", tableName));
811 std::forward<Func>(func)();
816 std::source_location location = std::source_location::current())
821 .PrimaryKeyWithAutoIncrement(
"EmployeeID")
822 .RequiredColumn(
"FirstName", Lightweight::SqlColumnTypeDefinitions::Varchar { 50 })
823 .Column(
"LastName", Lightweight::SqlColumnTypeDefinitions::Varchar { 50 })
824 .RequiredColumn(
"Salary", Lightweight::SqlColumnTypeDefinitions::Integer {});
833 for (
char c =
'A'; c <=
'Z'; ++c)
835 table.
Column(std::string(1, c), Lightweight::SqlColumnTypeDefinitions::Varchar { 50 });
844 .Set(
"FirstName", Lightweight::SqlWildcard)
845 .Set(
"LastName", Lightweight::SqlWildcard)
846 .Set(
"Salary", Lightweight::SqlWildcard));
847 stmt.
Execute(
"Alice",
"Smith", 50'000);
848 stmt.
Execute(
"Bob",
"Johnson", 60'000);
849 stmt.
Execute(
"Charlie",
"Brown", 70'000);
/// Builds a string of @p size characters that cycles through the uppercase Latin
/// alphabet: "ABC...XYZABC...".
///
/// @tparam T   character type of the resulting string (defaults to char).
/// @param size number of characters to generate.
/// @return a std::basic_string<T> of exactly @p size characters.
template <typename T = char>
inline auto MakeLargeText(size_t size)
{
    std::basic_string<T> result(size, T {});
    for (size_t index = 0; index < size; ++index)
        result[index] = static_cast<T>('A' + static_cast<int>(index % 26));
    return result;
}
/// Detects whether the test suite is running inside a GitHub Actions workflow.
///
/// @return true if the environment variable `GITHUB_ACTIONS` is set to exactly "true";
///         false if it is unset, has any other value, or cannot be read.
inline bool IsGithubActions()
{
#if defined(_WIN32) || defined(_WIN64)
    // getenv_s reports success (0) even when the variable is absent, leaving the
    // zero-initialized buffer empty, so the value itself must be compared to "true".
    // A value too long for the buffer makes getenv_s fail, which correctly yields false.
    char envBuffer[32] {};
    size_t requiredCount = 0;
    return getenv_s(&requiredCount, envBuffer, sizeof(envBuffer), "GITHUB_ACTIONS") == 0
           && std::string_view(envBuffer) == "true";
#else
    // Read the variable once instead of calling std::getenv twice.
    auto const* value = std::getenv("GITHUB_ACTIONS");
    return value != nullptr && std::string_view(value) == "true";
#endif
}
877ostream& operator<<(ostream& os, optional<T>
const& opt)
882 return os <<
"nullopt";
887template <
typename T, auto P1, auto P2>
888std::ostream& operator<<(std::ostream& os,
Lightweight::Field<std::optional<T>, P1, P2>
const& field)
891 return os << std::format(
"Field<{}> {{ {}, {} }}",
892 Reflection::TypeNameOf<T>,
894 field.IsModified() ?
"modified" :
"not modified");
899template <
typename T, auto P1, auto P2>
902 return os << std::format(
"Field<{}> {{ ", Reflection::TypeNameOf<T>) <<
"value: " << field.
Value() <<
"; "
903 << (field.
IsModified() ?
"modified" :
"not modified") <<
" }";
Represents a connection to a SQL database.
SqlServerType ServerType() const noexcept
Retrieves the type of the server.
static LIGHTWEIGHT_API void SetDefaultConnectionString(SqlConnectionString const &connectionString) noexcept
SQLHDBC NativeHandle() const noexcept
Retrieves the native handle.
static LIGHTWEIGHT_API void SetPostConnectedHook(std::function< void(SqlConnection &)> hook)
Sets a callback to be called after each connection is established.
LIGHTWEIGHT_API SqlCreateTableQueryBuilder & Column(SqlColumnDeclaration column)
Adds a new column to the table.
LIGHTWEIGHT_FORCE_INLINE constexpr decltype(auto) data(this auto &&self) noexcept
Retrieves the pointer to the string data.
LIGHTWEIGHT_FORCE_INLINE constexpr std::size_t size() const noexcept
Retrieves the size of the string.
LIGHTWEIGHT_FORCE_INLINE std::size_t size() const noexcept
Retrieves the string's size.
LIGHTWEIGHT_FORCE_INLINE T const * data() const noexcept
Retrieves the string's inner value (as T const*).
LIGHTWEIGHT_FORCE_INLINE constexpr std::size_t size() const noexcept
Returns the size of the string.
static LIGHTWEIGHT_API SqlLogger & TraceLogger()
Retrieves a logger that logs to the trace logger.
static LIGHTWEIGHT_API void SetLogger(SqlLogger &logger)
static LIGHTWEIGHT_API SqlLogger & StandardLogger()
Retrieves a logger that logs to standard output.
Query builder for building SQL migration queries.
LIGHTWEIGHT_API SqlCreateTableQueryBuilder CreateTable(std::string_view tableName)
Creates a new table.
LIGHTWEIGHT_API SqlInsertQueryBuilder Insert(std::vector< SqlVariant > *boundInputs=nullptr) noexcept
High level API for (prepared) raw SQL statements.
LIGHTWEIGHT_API void Prepare(std::string_view query) &
LIGHTWEIGHT_API SQLHSTMT NativeHandle() const noexcept
Retrieves the native handle of the statement.
void MigrateDirect(Callable const &callable, std::source_location location=std::source_location::current())
Executes an SQL migration query, as created by the callback.
LIGHTWEIGHT_API SqlConnection & Connection() noexcept
Retrieves the connection associated with this statement.
void Execute(Args const &... args)
Binds the given arguments to the prepared statement and executes it.
bool GetColumn(SQLUSMALLINT column, T *result) const
LIGHTWEIGHT_API SqlQueryBuilder Query(std::string_view const &table={}) const
Creates a new query builder for the given table, compatible with the SQL server being connected.
LIGHTWEIGHT_API void ExecuteDirect(std::string_view const &query, std::source_location location=std::source_location::current())
Executes the given query directly.
LIGHTWEIGHT_API bool FetchRow()
LIGHTWEIGHT_API std::u8string ToUtf8(std::u32string_view u32InputString)
Represents a single column in a table.
constexpr bool IsModified() const noexcept
Checks if the field has been modified.
constexpr T const & Value() const noexcept
Returns the value of the field.
Represents an ODBC connection string.
constexpr LIGHTWEIGHT_FORCE_INLINE native_type value() const noexcept
Returns the current date and time.
LIGHTWEIGHT_FORCE_INLINE constexpr std::chrono::year_month_day value() const noexcept
Returns the current date.
Represents an ODBC SQL error.
constexpr LIGHTWEIGHT_FORCE_INLINE auto ToUnscaledValue() const noexcept
Converts the numeric to an unscaled integer value.