Lightweight 0.20251202.0
Loading...
Searching...
No Matches
Utils.hpp
1// SPDX-License-Identifier: Apache-2.0
2
3#pragma once
4
5#if defined(_WIN32) || defined(_WIN64)
6 #include <Windows.h>
7#endif
8
9#include <Lightweight/Lightweight.hpp>
10
11#include <catch2/catch_session.hpp>
12#include <catch2/catch_test_macros.hpp>
13
14#include <algorithm>
15#include <chrono>
16#include <filesystem>
17#include <format>
18#include <mutex>
19#include <ostream>
20#include <ranges>
21#include <string>
22#include <string_view>
23#include <tuple>
24#include <variant>
25#include <vector>
26
27#include <yaml-cpp/yaml.h>
28
29#if __has_include(<stacktrace>)
30 #include <stacktrace>
31#endif
32
33#include <sql.h>
34#include <sqlext.h>
35#include <sqlspi.h>
36#include <sqltypes.h>
37
// A character type that is 16 bits (UTF-16 / UCS-2) wide on every platform: wchar_t where it is
// already two bytes wide, char16_t everywhere else. This lets a single set of UTF-16 tests run
// unchanged regardless of the platform's native wchar_t width.
using WideChar = std::conditional<sizeof(wchar_t) == 2, wchar_t, char16_t>::type;
using WideStringView = std::basic_string_view<WideChar>;
using WideString = std::basic_string<WideChar>;
42
#ifdef LIGHTWEIGHT_CXX26_REFLECTION
    /// @brief Names a structure member for the reflection API in use.
    ///
    /// With C++26 reflection this yields the reflection of the member (`^^x`);
    /// with the C++20 fallback it yields a pointer to the member (`&x`).
    #define Member(x) ^^x
#else
    /// @brief Names a structure member for the reflection API in use.
    ///
    /// With C++26 reflection this yields the reflection of the member (`^^x`);
    /// with the C++20 fallback it yields a pointer to the member (`&x`).
    #define Member(x) &x
#endif
53
#if defined(_WIN32)
    // On Windows wchar_t is 16 bits wide, so a wide (L"...") literal is already UTF-16.
    #define WTEXT(x) (L##x)
#else
    // Elsewhere wchar_t is typically not 16 bits wide; use a char16_t (u"...") literal instead.
    #define WTEXT(x) (u##x)
#endif
59
/// Ends the current (void-returning) test body early with a Catch2 warning when `stmt` is
/// connected to a server of type `dbType`.
///
/// NOTE: expands to a bare `if` statement that contains `return;`, so it may only be used inside
/// functions returning void, such as Catch2 test case bodies.
#define UNSUPPORTED_DATABASE(stmt, dbType)                                                             \
    if ((dbType) == (stmt).Connection().ServerType())                                                  \
    {                                                                                                  \
        WARN(std::format("TODO({}): This database is currently unsupported on this test.", dbType)); \
        return;                                                                                        \
    }
66
/// Formats a std::u8string as if it were a std::string holding the same UTF-8 code units, so
/// that u8 strings can be passed directly to std::format (e.g. in test diagnostics). All format
/// specifications supported for std::string (width, alignment, ...) apply unchanged.
template <>
struct std::formatter<std::u8string>: std::formatter<std::string>
{
    auto format(std::u8string const& value, std::format_context& ctx) const -> std::format_context::iterator
    {
        // Reinterpret the UTF-8 code units as char. Using data() + size() instead of c_str()
        // keeps embedded NUL characters and avoids an extra std::format pass.
        auto const text = std::string(reinterpret_cast<char const*>(value.data()), value.size());
        return std::formatter<std::string>::format(text, ctx);
    }
};
77
namespace std
{

// Add support for std::basic_string<WideChar> and std::basic_string_view<WideChar> to std::ostream,
// so that we can get them pretty-printed in REQUIRE() and CHECK() macros.

/// Streams a wide (wchar_t / UTF-16 / UTF-32) string or string view as
/// `UTF-<bits>{length: <code units>, characters: "<text converted to UTF-8>"}`.
template <typename WideStringT>
    requires(Lightweight::detail::OneOf<WideStringT,
                                        std::wstring,
                                        std::wstring_view,
                                        std::u16string,
                                        std::u16string_view,
                                        std::u32string,
                                        std::u32string_view>)
ostream& operator<<(ostream& os, WideStringT const& str)
{
    auto constexpr BitsPerChar = sizeof(typename WideStringT::value_type) * 8;
    auto const u8String = Lightweight::ToUtf8(str);
    return os << "UTF-" << BitsPerChar << '{' << "length: " << str.size() << ", characters: " << '"'
              << string_view((char const*) u8String.data(), u8String.size()) << '"' << '}';
}

/// Streams a SqlGuid as `SqlGuid(<formatted guid>)`.
inline ostream& operator<<(ostream& os, Lightweight::SqlGuid const& guid)
{
    return os << format("SqlGuid({})", guid);
}

/// Renders binary data as readable text: bytes for which std::isprint() holds are copied
/// verbatim, every other byte is written as a `\xNN` escape (two lowercase hex digits).
inline std::string EscapedBinaryText(std::string_view binary)
{
    std::string hexEncodedString;
    hexEncodedString.reserve(binary.size());
    for (char const c: binary)
    {
        // std::isprint() is undefined for negative arguments other than EOF, which is what bytes
        // >= 0x80 become when char is signed -> classify through unsigned char.
        auto const byte = static_cast<unsigned char>(c);
        if (std::isprint(byte))
            hexEncodedString += c;
        else
            hexEncodedString += std::format("\\x{:02x}", byte);
    }
    return hexEncodedString;
}

/// Streams a SqlDynamicBinary<N> as
/// `SqlDynamicBinary<N>(length: <size>, characters: <escaped bytes>)` (see EscapedBinaryText).
template <size_t N>
inline ostream& operator<<(ostream& os, Lightweight::SqlDynamicBinary<N> const& binary)
{
    auto const hexEncodedString = EscapedBinaryText(std::string_view((char const*) binary.data(), binary.size()));
    return os << std::format("SqlDynamicBinary<{}>(length: {}, characters: {})", N, binary.size(), hexEncodedString);
}

} // namespace std
126
namespace std
{
/// Streams a SqlNumeric<Precision, Scale> as
/// `SqlNumeric<Precision, Scale>(sign, precision, scale, unscaledValue)`, where sign, precision
/// and scale are read from the wrapped `sqlValue`.
template <std::size_t Precision, std::size_t Scale>
std::ostream& operator<<(std::ostream& os, Lightweight::SqlNumeric<Precision, Scale> const& value)
{
    auto const& raw = value.sqlValue;
    auto const text = std::format("SqlNumeric<{}, {}>({}, {}, {}, {})",
                                  Precision,
                                  Scale,
                                  raw.sign,
                                  raw.precision,
                                  raw.scale,
                                  value.ToUnscaledValue());
    return os << text;
}
} // namespace std
141
// Default ODBC connection string of the test suite, used when neither --test-env=NAME nor the
// ODBC_CONNECTION_STRING environment variable provides one. It requires the sqliteodbc driver to
// be installed and refers to the SQLite database file "test.db" (a file database, not an
// in-memory one); the driver presumably resolves it relative to the current working directory.
// See:
// - http://www.ch-werner.de/sqliteodbc/
// - https://github.com/softace/sqliteodbc
inline auto const DefaultTestConnectionString = [] { // NOLINT(bugprone-throwing-static-initialization)
#if defined(_WIN32) || defined(_WIN64)
    constexpr std::string_view driverName = "SQLite3 ODBC Driver";
#else
    constexpr std::string_view driverName = "SQLite3";
#endif
    constexpr std::string_view databaseFile = "test.db";
    return Lightweight::SqlConnectionString {
        .value = std::format("DRIVER={};Database={}", driverName, databaseFile),
    };
}();
160
/// SqlLogger implementation used by the test suite.
///
/// Query and fetch events are written as informational lines prefixed with "[Lightweight]" via
/// WriteRawInfo(). Errors and warnings are reported through Catch2's WARN(), followed by the
/// source location, the most recently prepared query (if any) and, where the standard library
/// supports it, a stack trace.
class TestSuiteSqlLogger: public Lightweight::SqlLogger::Null
{
  private:
    mutable std::mutex m_mutex;      // Serializes access to m_lastPreparedQuery.
    std::string m_lastPreparedQuery; // Last query passed to OnPrepare().

    // Emits one already formatted informational line (defined out of line).
    void WriteRawInfo(std::string_view message);

    // Formats a message, prefixes it with "[Lightweight]" and hands it to WriteRawInfo().
    template <typename... Args>
    void WriteInfo(std::format_string<Args...> const& fmt, Args&&... args)
    {
        auto message = std::format(fmt, std::forward<Args>(args)...);
        message = std::format("[{}] {}", "Lightweight", message);
        WriteRawInfo(message);
    }

    // Formats a message and reports it as a Catch2 warning.
    template <typename... Args>
    void WriteWarning(std::format_string<Args...> const& fmt, Args&&... args)
    {
        WARN(std::format(fmt, std::forward<Args>(args)...));
    }

  public:
    /// Returns the single, lazily constructed logger instance.
    static TestSuiteSqlLogger& GetLogger() noexcept
    {
        static TestSuiteSqlLogger theLogger;
        return theLogger;
    }

    void OnError(Lightweight::SqlError error, std::source_location sourceLocation) override
    {
        WriteWarning("SQL Error: {}", error);
        WriteDetails(sourceLocation);
    }

    void OnError(Lightweight::SqlErrorInfo const& errorInfo, std::source_location sourceLocation) override
    {
        WriteWarning("SQL Error: {}", errorInfo);
        WriteDetails(sourceLocation);
    }

    void OnWarning(std::string_view const& message) override
    {
        WriteWarning("{}", message);
        WriteDetails(std::source_location::current());
    }

    void OnExecuteDirect(std::string_view const& query) override
    {
        WriteInfo("ExecuteDirect: {}", query);
    }

    void OnPrepare(std::string_view const& query) override
    {
        std::scoped_lock lock(m_mutex);
        m_lastPreparedQuery = query;
    }

    void OnExecute(std::string_view const& query) override
    {
        WriteInfo("Execute: {}", query);
    }

    void OnExecuteBatch() override
    {
        std::scoped_lock lock(m_mutex);
        WriteInfo("ExecuteBatch: {}", m_lastPreparedQuery);
    }

    void OnFetchRow() override
    {
        WriteInfo("Fetched row");
    }

    void OnFetchEnd() override
    {
        // WriteInfo("Fetch end");
    }

  private:
    /// Writes the diagnostic context of an error/warning: the source location, the last prepared
    /// query (if any) and, when available, up to 25 stack frames (the innermost frame skipped).
    void WriteDetails(std::source_location sourceLocation)
    {
        std::scoped_lock lock(m_mutex);
        WriteInfo(" Source: {}:{}", sourceLocation.file_name(), sourceLocation.line());
        if (!m_lastPreparedQuery.empty())
            WriteInfo(" Query: {}", m_lastPreparedQuery);
        WriteInfo(" Stack trace:");

#if defined(__cpp_lib_stacktrace)
        // Test the library feature macro rather than header presence: <stacktrace> can exist
        // without std::stacktrace being usable (e.g. in C++20 mode or without backtrace support).
        auto stackTrace = std::stacktrace::current(1, 25);
        for (auto const& entry: stackTrace)
            WriteInfo(" {}", entry);
#endif
    }
};
256
257// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
258class ScopedSqlNullLogger: public Lightweight::SqlLogger::Null
259{
260 private:
261 SqlLogger& m_previousLogger = SqlLogger::GetLogger();
262
263 public:
264 ScopedSqlNullLogger()
265 {
266 SqlLogger::SetLogger(*this);
267 }
268
269 ~ScopedSqlNullLogger() override
270 {
271 SqlLogger::SetLogger(m_previousLogger);
272 }
273};
274
/// Repeatedly applies `callable` to the value produced by `getter` until a fixed point is reached.
///
/// `callable` receives the current value; afterwards `getter` is queried again. Iteration ends as
/// soon as the freshly obtained value compares equal to the one just processed. So `callable`
/// runs at least once and is never invoked twice in a row with equal values.
template <typename Getter, typename Callable>
constexpr void FixedPointIterate(Getter const& getter, Callable const& callable)
{
    auto current = getter();
    while (true)
    {
        callable(current);
        auto next = getter();
        if (current == next)
            return;
        current = std::move(next);
    }
}
288
/// Locates the test-environment configuration file.
///
/// Lookup order (the first existing file wins):
///   1. `.test-env.local.yml` in the current working directory or one of its ancestors.
///   2. `<projectRoot>/scripts/tests/.test-env.local.yml`
///   3. `.test-env.yml` in the current working directory or one of its ancestors.
///   4. `<projectRoot>/scripts/tests/.test-env.yml`
///
/// The upward search stops at the project root or at the filesystem root. The project root is
/// the first directory, starting from the working directory, that contains a `.git` entry; it is
/// itself still searched. Steps 2 and 4 are skipped when no project root is found.
///
/// Paths whose existence cannot be determined (e.g. because of permission errors) count as
/// missing instead of raising std::filesystem::filesystem_error.
///
/// @return The path of the first matching file, or std::nullopt if none exists.
inline std::optional<std::filesystem::path> FindTestEnvFile()
{
    namespace fs = std::filesystem;

    // Non-throwing existence check: any filesystem error is treated as "does not exist".
    auto const existsNoThrow = [](fs::path const& path) {
        auto ec = std::error_code {};
        return fs::exists(path, ec);
    };

    // Outcome of one upward search: the file (if found) and the project root (left empty when the
    // file was found, or when the filesystem root was reached without seeing ".git").
    struct UpwardSearchResult
    {
        std::optional<fs::path> file;
        fs::path projectRoot;
    };

    // Walks from startDir towards the filesystem root looking for fileName, stopping after the
    // first directory that contains ".git".
    auto const searchUpward = [&](fs::path const& startDir, std::string_view fileName) {
        auto dir = startDir;
        while (true)
        {
            auto candidate = dir / fs::path(fileName);
            if (existsNoThrow(candidate))
                return UpwardSearchResult { std::move(candidate), {} };

            if (existsNoThrow(dir / ".git"))
                return UpwardSearchResult { std::nullopt, dir };

            auto parent = dir.parent_path();
            if (parent == dir)
                return UpwardSearchResult { std::nullopt, {} };
            dir = std::move(parent);
        }
    };

    auto const currentDir = fs::current_path();

    // Local overrides take precedence over the default configuration.
    for (std::string_view const fileName: { std::string_view(".test-env.local.yml"), std::string_view(".test-env.yml") })
    {
        auto result = searchUpward(currentDir, fileName);
        if (result.file)
            return result.file;

        if (!result.projectRoot.empty())
        {
            auto scriptsPath = result.projectRoot / "scripts" / "tests" / fs::path(fileName);
            if (existsNoThrow(scriptsPath))
                return scriptsPath;
        }
    }

    return std::nullopt;
}
354
/// Catch2 test fixture shared by the SQL test cases.
///
/// The static members hold suite-wide state configured by Initialize(), which the test program
/// calls with its argc/argv. Constructing a fixture:
///   - sets `running`,
///   - requires that a default SqlStatement is alive,
///   - refreshes testDatabaseName from the connection,
///   - drops the tables of the test database via DropAllTablesInDatabase().
355// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
356class SqlTestFixture
357{
358 public:
// Test database (catalog) name. The constructor overwrites it with the name the driver reports
// for SQL_DATABASE_NAME, when that name is non-empty.
359 static inline std::string testDatabaseName = "LightweightTest"; // NOLINT(bugprone-throwing-static-initialization)
// Set by --trace-odbc; read by PostConnectedHook() to switch on ODBC tracing.
360 static inline bool odbcTrace = false;
// Set to true by the constructor and back to false by the destructor.
361 static inline std::atomic<bool> running = false;
362
// (argc, argv) handed on to Catch2 after the suite's own options have been consumed.
363 using MainProgramArgs = std::tuple<int, char**>;
364
/// Parses the suite's leading command-line options and reports which ODBC connection string is
/// used. It then installs PostConnectedHook() and checks that a connection can be opened
/// (std::abort() if not).
///
/// Options are consumed until the first unrecognized argument or a literal "--":
///   --trace-sql, --trace-odbc, --test-env=NAME, --help / -h
///
/// Connection string source, in order:
///   1. --test-env=NAME, looked up under ODBC_CONNECTION_STRING in the file found by
///      FindTestEnvFile();
///   2. the ODBC_CONNECTION_STRING environment variable;
///   3. DefaultTestConnectionString.
///
/// @return An exit code when the program should stop immediately (--help, or a --test-env
///         configuration error). Otherwise the argument range starting at argv + (i - 1), for
///         Catch2.
365 // NOLINTNEXTLINE(readability-function-cognitive-complexity)
366 static std::variant<MainProgramArgs, int> Initialize(int argc, char** argv)
367 {
368 Lightweight::SqlLogger::SetLogger(TestSuiteSqlLogger::GetLogger());
369
370 using namespace std::string_view_literals;
371 std::optional<std::string> testEnvName;
372 int i = 1;
373 for (; i < argc; ++i)
374 {
// NOTE(review): the statement run for --trace-sql is not visible in this listing; per the
// --help text it enables SQL tracing.
375 if (argv[i] == "--trace-sql"sv)
377 else if (argv[i] == "--trace-odbc"sv)
378 odbcTrace = true;
// Everything after the 11-character prefix "--test-env=" is the environment name.
379 else if (std::string_view(argv[i]).starts_with("--test-env="))
380 testEnvName = std::string_view(argv[i]).substr(11);
381 else if (argv[i] == "--help"sv || argv[i] == "-h"sv)
382 {
383 std::println("{} [--test-env=NAME] [--trace-sql] [--trace-odbc] [[--] [Catch2 flags ...]]", argv[0]);
384 std::println("");
385 std::println("Options:");
386 std::println(" --test-env=NAME Use connection string from .test-env.yml (e.g., pgsql, mssql, sqlite)");
387 std::println(" --trace-sql Enable SQL tracing");
388 std::println(" --trace-odbc Enable ODBC tracing");
389 return { EXIT_SUCCESS };
390 }
391 else if (argv[i] == "--"sv)
392 {
// Skip the separator itself; everything after it belongs to Catch2.
393 ++i;
394 break;
395 }
396 else
397 break;
398 }
399
// Put the program name into the slot just before the first argument left for Catch2, so the
// range returned at the end starts with argv[0].
// NOTE(review): when options were consumed and nothing is left (i == argc > 1), nothing is
// copied and Catch2's argv[0] is the last option consumed here.
400 if (i < argc)
401 argv[i - 1] = argv[0];
402
403 // Connection string priority:
404 // 1. --test-env=NAME (must find valid entry, error if not found)
405 // 2. Environment variable ODBC_CONNECTION_STRING
406 // 3. Default SQLite3
407
408 if (testEnvName)
409 {
410 auto configPath = FindTestEnvFile();
411 if (!configPath)
412 {
413 std::println(stderr,
414 "Error: .test-env.yml not found (searched from '{}' to project root)",
415 std::filesystem::current_path().string());
416 return { EXIT_FAILURE };
417 }
418
// Expected layout: a map ODBC_CONNECTION_STRING: { <envName>: <connection string>, ... }.
419 try
420 {
421 YAML::Node config = YAML::LoadFile(configPath->string());
422 auto connectionStrings = config["ODBC_CONNECTION_STRING"];
423 if (!connectionStrings || !connectionStrings[*testEnvName])
424 {
425 std::println(stderr, "Error: Key '{}' not found in ODBC_CONNECTION_STRING", *testEnvName);
426 if (connectionStrings && connectionStrings.IsMap())
427 {
428 std::print(stderr, "Available environments:");
429 for (auto const& entry: connectionStrings)
430 std::print(stderr, " {}", entry.first.as<std::string>());
431 std::println(stderr, "");
432 }
433 return { EXIT_FAILURE };
434 }
435 auto connStr = connectionStrings[*testEnvName].as<std::string>();
436 if (connStr.empty())
437 {
438 std::println(stderr, "Error: Connection string for '{}' is empty", *testEnvName);
439 return { EXIT_FAILURE };
440 }
441 std::println("Using test environment '{}' from: {}", *testEnvName, configPath->string());
442 std::println("Using ODBC connection string: '{}'", Lightweight::SqlConnectionString::SanitizePwd(connStr));
// NOTE(review): no statement applying connStr is visible in this listing; presumably it is
// installed as the default connection string (cf. SqlConnection::SetDefaultConnectionString).
444 }
445 catch (YAML::Exception const& e)
446 {
447 std::println(stderr, "Error parsing {}: {}", configPath->string(), e.what());
448 return { EXIT_FAILURE };
449 }
450 }
451 else
452 {
453#if defined(_MSC_VER)
// NOTE(review): the buffer allocated by _dupenv_s() is never released with free() in the
// visible code.
454 char* envBuffer = nullptr;
455 size_t envBufferLen = 0;
456 _dupenv_s(&envBuffer, &envBufferLen, "ODBC_CONNECTION_STRING");
457 if (auto const* s = envBuffer; s && *s)
458#else
459 if (auto const* s = std::getenv("ODBC_CONNECTION_STRING"); s && *s)
460#endif
461
462 {
463 std::println("Using ODBC connection string: '{}'", Lightweight::SqlConnectionString::SanitizePwd(s));
// NOTE(review): as above, the statement applying `s` is not visible in this listing.
465 }
466 else
467 {
468 // Fall back to DefaultTestConnectionString (the SQLite file database "test.db"), for testing purposes
469 std::println("Using default ODBC connection string: '{}'", DefaultTestConnectionString.value);
// NOTE(review): as above, the statement applying the default string is not visible here.
471 }
472 }
473
474 Lightweight::SqlConnection::SetPostConnectedHook(&SqlTestFixture::PostConnectedHook);
475
// Fail fast: without a working connection none of the test cases can run.
476 auto sqlConnection = Lightweight::SqlConnection();
477 if (!sqlConnection.IsAlive())
478 {
479 std::println("Failed to connect to the database: {}", sqlConnection.LastError());
480 std::abort();
481 }
482
483 std::println("Running test cases against: {} ({}) (identified as: {})",
484 sqlConnection.ServerName(),
485 sqlConnection.ServerVersion(),
486 sqlConnection.ServerType());
487
488 return MainProgramArgs { argc - (i - 1), argv + (i - 1) };
489 }
490
/// Connection hook installed by Initialize(). With --trace-odbc it enables ODBC tracing into
/// the console ("/dev/stdout", or "CONOUT$" on Windows). For SQLite it also turns on
/// foreign-key enforcement.
491 static void PostConnectedHook(Lightweight::SqlConnection& connection)
492 {
493 if (odbcTrace)
494 {
495 auto const traceFile = []() -> std::string_view {
496#if !defined(_WIN32) && !defined(_WIN64)
497 return "/dev/stdout";
498#else
499 return "CONOUT$";
500#endif
501 }();
502
// traceFile views a string literal, so data() is NUL-terminated as SQL_NTS requires.
503 SQLHDBC handle = connection.NativeHandle();
504 SQLSetConnectAttrA(handle, SQL_ATTR_TRACEFILE, (SQLPOINTER) traceFile.data(), SQL_NTS);
505 SQLSetConnectAttrA(handle, SQL_ATTR_TRACE, (SQLPOINTER) SQL_OPT_TRACE_ON, SQL_IS_UINTEGER);
506 }
507
508 using Lightweight::SqlServerType;
509 switch (connection.ServerType())
510 {
511 case SqlServerType::SQLITE: {
512 auto stmt = Lightweight::SqlStatement { connection };
513 // Enable foreign key constraints for SQLite
514 stmt.ExecuteDirect("PRAGMA foreign_keys = ON");
515 break;
516 }
517 case SqlServerType::MICROSOFT_SQL:
518 case SqlServerType::POSTGRESQL:
519 case SqlServerType::MYSQL:
520 case SqlServerType::UNKNOWN:
521 break;
522 }
523 }
524
/// Marks the fixture as running, requires a live default statement, refreshes
/// testDatabaseName from SQL_DATABASE_NAME and drops the tables of the test database.
525 SqlTestFixture()
526 {
527 running = true;
528 auto stmt = Lightweight::SqlStatement();
529 REQUIRE(stmt.IsAlive());
530
// NOTE(review): SQLGetInfo's return code is not checked; dbName (uninitialized) is only read
// when the reported length is positive.
531 char dbName[100]; // Buffer to store the database name
532 SQLSMALLINT dbNameLen {};
533 SQLGetInfo(stmt.Connection().NativeHandle(), SQL_DATABASE_NAME, dbName, sizeof(dbName), &dbNameLen);
534 if (dbNameLen > 0)
535 testDatabaseName = dbName;
536
537 DropAllTablesInDatabase(stmt);
538 }
539
/// Resets the running flag.
540 virtual ~SqlTestFixture()
541 {
542 running = false;
544 }
545
/// Joins `values` with `separator`. A separator is only inserted once the accumulated result is
/// non-empty, so leading empty values produce no leading separators.
546 static std::string ToString(std::vector<std::string> const& values, std::string_view separator)
547 {
548 auto result = std::string {};
549 for (auto const& value: values)
550 {
551 if (!result.empty())
552 result += separator;
553 result += value;
554 }
555 return result;
556 }
557
/// Drops `table` after first (recursively) dropping every table whose foreign keys reference it.
/// Only `table.table` is used in the DROP statement; catalog and schema are ignored there.
/// NOTE(review): there is no visited set. If AllForeignKeysTo() reports a table referencing
/// itself, or tables reference each other in a cycle, this recursion never terminates. Confirm
/// that this cannot happen in the test schemas.
558 static void DropTableRecursively(Lightweight::SqlStatement& stmt,
559 Lightweight::SqlSchema::FullyQualifiedTableName const& table)
560 {
561 auto const dependantTables = Lightweight::SqlSchema::AllForeignKeysTo(stmt, table);
562 for (auto const& dependantTable: dependantTables)
563 DropTableRecursively(stmt, dependantTable.foreignKey.table);
564 stmt.ExecuteDirect(std::format("DROP TABLE IF EXISTS \"{}\"", table.table));
565 }
566
/// Best-effort `DROP TABLE IF EXISTS <tableName>` on a fresh statement of `conn`. Any exception
/// is deliberately swallowed. tableName is inserted into the SQL verbatim (unquoted).
567 static void DropTableIfExists(Lightweight::SqlConnection& conn, std::string const& tableName)
568 {
569 Lightweight::SqlStatement stmt { conn };
570 try
571 {
572 stmt.ExecuteDirect(std::format("DROP TABLE IF EXISTS {}", tableName));
573 }
574 catch (...)
575 {
576 ; // ignore
577 }
578 }
579
/// Drops the tables of the test database.
///  - MS SQL / MySQL: first switch to testDatabaseName, then continue as for SQLite.
///  - SQLite / unknown: drop every table reported by GetAllTableNames() except
///    "sqlite_sequence", dropping referencing tables first (DropTableRecursively).
///  - PostgreSQL: drop the tables with DROP ... CASCADE, in reverse order.
580 static void DropAllTablesInDatabase(Lightweight::SqlStatement& stmt)
581 {
582 using Lightweight::SqlServerType;
583 switch (stmt.Connection().ServerType())
584 {
585 case SqlServerType::MICROSOFT_SQL:
586 case SqlServerType::MYSQL:
587 stmt.ExecuteDirect(std::format("USE \"{}\"", testDatabaseName));
588 [[fallthrough]];
589 case SqlServerType::SQLITE:
590 case SqlServerType::UNKNOWN: {
591 auto const tableNames = GetAllTableNames(stmt);
592 for (auto const& tableName: tableNames)
593 {
// SQLite's internal AUTOINCREMENT bookkeeping table cannot be dropped.
594 if (tableName == "sqlite_sequence")
595 continue;
596
597 DropTableRecursively(stmt,
598 Lightweight::SqlSchema::FullyQualifiedTableName {
599 .catalog = {},
600 .schema = {},
601 .table = tableName,
602 });
603 }
604 break;
605 }
606 case SqlServerType::POSTGRESQL:
607 if (m_createdTables.empty())
608 m_createdTables = GetAllTableNames(stmt);
609 for (auto& createdTable: std::views::reverse(m_createdTables))
610 stmt.ExecuteDirect(std::format("DROP TABLE IF EXISTS \"{}\" CASCADE", createdTable));
611 break;
612 }
613 m_createdTables.clear();
614 }
615
/// Schema used for table lookups: "dbo" for Microsoft SQL Server, empty for all other servers.
616 static std::string GetDefaultSchemaName(Lightweight::SqlConnection const& connection)
617 {
618 using namespace std::string_literals;
619 switch (connection.ServerType())
620 {
621 case Lightweight::SqlServerType::MICROSOFT_SQL:
622 return "dbo"s;
623 default:
624 return ""s;
625 }
626 }
627
628 private:
/// Lists the names of all tables of type "TABLE" via ODBC SQLTables(). The catalog is
/// testDatabaseName and the schema is GetDefaultSchemaName(). Returns an empty list if
/// SQLTables() fails.
/// NOTE(review): per the ODBC specification, a zero-length schema name selects objects that
/// have no schema. Confirm this yields the expected tables on servers where every table has a
/// schema (e.g. PostgreSQL's "public").
629 static std::vector<std::string> GetAllTableNames(Lightweight::SqlStatement& stmt)
630 {
631 using namespace std::string_literals;
632 auto result = std::vector<std::string>();
633 auto const schemaName = GetDefaultSchemaName(stmt.Connection());
634 auto const sqlResult = SQLTables(stmt.NativeHandle(),
635 (SQLCHAR*) testDatabaseName.data(),
636 (SQLSMALLINT) testDatabaseName.size(),
637 (SQLCHAR*) schemaName.data(),
638 (SQLSMALLINT) schemaName.size(),
639 nullptr,
640 0,
641 (SQLCHAR*) "TABLE",
642 SQL_NTS);
643 if (SQL_SUCCEEDED(sqlResult))
644 {
645 while (stmt.FetchRow())
646 {
// Column 3 of the SQLTables result set is TABLE_NAME.
647 result.emplace_back(stmt.GetColumn<std::string>(3)); // table name
648 }
649 }
650 return result;
651 }
652
// Scratch list used by the PostgreSQL branch of DropAllTablesInDatabase(). Within this class it
// is only filled from GetAllTableNames() and is cleared at the end of every
// DropAllTablesInDatabase() call.
653 static inline std::vector<std::string> m_createdTables;
654};
655
656// {{{ ostream support for Lightweight, for debugging purposes
/// Streams a SqlText as `SqlText(<text>)`.
inline std::ostream& operator<<(std::ostream& os, Lightweight::SqlText const& value)
{
    auto const rendered = std::format("SqlText({})", value.value);
    return os << rendered;
}
661
/// Streams a SqlDate as `SqlDate { YYYY-MM-DD }` with zero-padded numeric fields, matching the
/// date part printed for SqlDateTime.
inline std::ostream& operator<<(std::ostream& os, Lightweight::SqlDate const& date)
{
    auto const ymd = date.value();
    // Print plain numbers: formatting std::chrono::month directly yields its abbreviated name
    // (e.g. "Jan"), unlike the numeric month in the SqlDateTime output.
    return os << std::format("SqlDate {{ {:04}-{:02}-{:02} }}",
                             static_cast<int>(ymd.year()),
                             static_cast<unsigned>(ymd.month()),
                             static_cast<unsigned>(ymd.day()));
}
667
/// Streams a SqlTime as `SqlTime { HH:MM:SS.ffffff }`: hours, minutes and seconds are zero-padded
/// to two digits and the subseconds to at least six.
inline std::ostream& operator<<(std::ostream& os, Lightweight::SqlTime const& time)
{
    auto const hms = time.value();
    auto const hours = hms.hours().count();
    auto const minutes = hms.minutes().count();
    auto const seconds = hms.seconds().count();
    auto const fraction = hms.subseconds().count();
    return os << std::format("SqlTime {{ {:02}:{:02}:{:02}.{:06} }}", hours, minutes, seconds, fraction);
}
677
/// Streams a SqlDateTime as `SqlDateTime { YYYY-MM-DD HH:MM:SS.nnnnnnnnn }`.
///
/// The date is the calendar day containing the time point. The time of day is measured from the
/// start of that day, with sub-second precision shown in nanoseconds (zero-padded to at least
/// nine digits).
inline std::ostream& operator<<(std::ostream& os, Lightweight::SqlDateTime const& datetime)
{
    namespace chr = std::chrono;

    auto const timePoint = datetime.value();
    auto const dayStart = chr::floor<chr::days>(timePoint);
    auto const calendarDate = chr::year_month_day { dayStart };
    auto const timeOfDay = chr::hh_mm_ss<chr::nanoseconds> { chr::floor<chr::nanoseconds>(timePoint - dayStart) };

    auto const year = static_cast<int>(calendarDate.year());
    auto const month = static_cast<unsigned>(calendarDate.month());
    auto const day = static_cast<unsigned>(calendarDate.day());

    return os << std::format("SqlDateTime {{ {:04}-{:02}-{:02} {:02}:{:02}:{:02}.{:09} }}",
                             year,
                             month,
                             day,
                             timeOfDay.hours().count(),
                             timeOfDay.minutes().count(),
                             timeOfDay.seconds().count(),
                             timeOfDay.subseconds().count());
}
694
/// Streams a SqlFixedString for debugging. The rendering depends on the string mode:
///  - FIXED_SIZE:               `SqlFixedString<N> { size: <size>, data: '<data>' }`
///  - FIXED_SIZE_RIGHT_TRIMMED: `SqlTrimmedFixedString<N> { '<data>' }`
///  - VARIABLE_SIZE:            `SqlVariableString<N> { size: <size>, '<data>' }` for char, otherwise
///                              `SqlVariableString<N, <T>> { size: <size>, '<data as UTF-8>' }`
///  - any other mode:           same as FIXED_SIZE.
template <std::size_t N, typename T, Lightweight::SqlFixedStringMode Mode>
inline std::ostream& operator<<(std::ostream& os, Lightweight::SqlFixedString<N, T, Mode> const& value)
{
    if constexpr (Mode == Lightweight::SqlFixedStringMode::FIXED_SIZE)
        return os << std::format("SqlFixedString<{}> {{ size: {}, data: '{}' }}", N, value.size(), value.data());
    else if constexpr (Mode == Lightweight::SqlFixedStringMode::FIXED_SIZE_RIGHT_TRIMMED)
        return os << std::format("SqlTrimmedFixedString<{}> {{ '{}' }}", N, value.data());
    else if constexpr (Mode == Lightweight::SqlFixedStringMode::VARIABLE_SIZE)
    {
        if constexpr (std::same_as<T, char>)
            return os << std::format("SqlVariableString<{}> {{ size: {}, '{}' }}", N, value.size(), value.data());
        else
        {
            // Qualified on purpose: T is a character type, so argument-dependent lookup for a
            // std::basic_string_view<T> argument only searches namespace std and would not find
            // Lightweight::ToUtf8 from this global-namespace operator.
            auto u8String = Lightweight::ToUtf8(std::basic_string_view<T>(value.data(), value.size()));
            return os << std::format("SqlVariableString<{}, {}> {{ size: {}, '{}' }}",
                                     N,
                                     Reflection::TypeNameOf<T>,
                                     value.size(),
                                     (char const*) u8String.c_str()); // NOLINT(readability-redundant-string-cstr)
        }
    }
    else
        return os << std::format("SqlFixedString<{}> {{ size: {}, data: '{}' }}", N, value.size(), value.data());
}
719
/// Streams a SqlDynamicString for debugging:
///  - char strings:  `SqlDynamicString<N> { size: <size>, '<data>' }`
///  - other T:       `SqlDynamicString<N, <T>> { size: <size>, '<data as UTF-8>' }`
template <std::size_t N, typename T>
inline std::ostream& operator<<(std::ostream& os, Lightweight::SqlDynamicString<N, T> const& value)
{
    if constexpr (std::same_as<T, char>)
        return os << std::format("SqlDynamicString<{}> {{ size: {}, '{}' }}", N, value.size(), value.data());
    else
    {
        // Qualified on purpose: argument-dependent lookup for std::basic_string_view<T> only
        // searches namespace std and would not find Lightweight::ToUtf8 from here.
        auto u8String = Lightweight::ToUtf8(std::basic_string_view<T>(value.data(), value.size()));
        return os << std::format("SqlDynamicString<{}, {}> {{ size: {}, '{}' }}",
                                 N,
                                 Reflection::TypeNameOf<T>,
                                 value.size(),
                                 (char const*) u8String.c_str()); // NOLINT(readability-redundant-string-cstr)
    }
}
735
/// Normalizes the whitespace of `text` so that texts differing only in whitespace layout compare
/// equal.
///
///  - Leading and trailing whitespace is removed.
///  - Every inner run of consecutive whitespace characters is collapsed to the run's first
///    character. That character is kept as-is: a newline stays a newline, and whitespace is not
///    converted to ' '.
///
/// Whitespace is classified with std::isspace() in the current C locale.
[[nodiscard]] inline std::string NormalizeText(std::string_view const& text)
{
    // std::isspace() is undefined for negative values other than EOF. Bytes >= 0x80 (e.g. UTF-8
    // sequences) become negative when char is signed, so classify through unsigned char.
    auto const isSpace = [](char c) { return std::isspace(static_cast<unsigned char>(c)) != 0; };

    auto result = std::string {};
    result.reserve(text.size());
    for (char const c: text)
    {
        // Skip whitespace at the very beginning and whitespace directly following whitespace;
        // this keeps exactly the first character of every inner whitespace run.
        if (isSpace(c) && (result.empty() || isSpace(result.back())))
            continue;
        result += c;
    }

    // After collapsing, at most one trailing whitespace character can remain.
    if (!result.empty() && isSpace(result.back()))
        result.pop_back();

    return result;
}
755
756[[nodiscard]] inline std::string NormalizeText(std::vector<std::string> const& texts)
757{
758 auto result = std::string {};
759 for (auto const& text: texts)
760 {
761 if (!result.empty())
762 result += '\n';
763 result += NormalizeText(text);
764 }
765 return result;
766}
767
768// }}}
769
770/// Enables foreign keys for SQLite (no-op for other DBMS).
771inline void EnableForeignKeysIfNeeded(Lightweight::SqlConnection& conn)
772{
773 if (conn.ServerType() == Lightweight::SqlServerType::SQLITE)
774 {
775 Lightweight::SqlStatement stmt { conn };
776 stmt.ExecuteDirect("PRAGMA foreign_keys = ON");
777 }
778}
779
780/// Disables foreign keys for SQLite (no-op for other DBMS).
781inline void DisableForeignKeysIfNeeded(Lightweight::SqlConnection& conn)
782{
783 if (conn.ServerType() == Lightweight::SqlServerType::SQLITE)
784 {
785 Lightweight::SqlStatement stmt { conn };
786 stmt.ExecuteDirect("PRAGMA foreign_keys = OFF");
787 }
788}
789
/// Runs `func` with `SET IDENTITY_INSERT "<tableName>" ON` in effect when `stmt` is connected to
/// Microsoft SQL Server, and switches it back OFF afterwards. This also happens when `func`
/// throws; that exception is then rethrown. For all other servers `func` is simply invoked.
///
/// @param stmt      statement used to toggle IDENTITY_INSERT.
/// @param tableName table name, inserted double-quoted into the SET statement.
/// @param func      callable invoked with no arguments.
template <typename Func>
inline void WithIdentityInsert(Lightweight::SqlStatement& stmt, std::string_view tableName, Func&& func)
{
    if (stmt.Connection().ServerType() != Lightweight::SqlServerType::MICROSOFT_SQL)
    {
        std::forward<Func>(func)();
        return;
    }

    auto const setIdentityInsert = [&](std::string_view state) {
        stmt.ExecuteDirect(std::format("SET IDENTITY_INSERT \"{}\" {}", tableName, state));
    };

    setIdentityInsert("ON");
    try
    {
        std::forward<Func>(func)();
    }
    catch (...)
    {
        // Best-effort cleanup. If switching IDENTITY_INSERT off fails as well, keep propagating
        // the exception thrown by func rather than the cleanup failure.
        try
        {
            setIdentityInsert("OFF");
        }
        catch (...) // NOLINT(bugprone-empty-catch)
        {
            // Deliberately ignored: the exception from func is the one worth reporting.
        }
        throw;
    }

    // Outside the try block: if this fails, the error propagates once and OFF is not retried.
    setIdentityInsert("OFF");
}
814
/// Creates the "Employees" test table through stmt.MigrateDirect():
///   EmployeeID  auto-increment primary key
///   FirstName   VARCHAR(50), declared via RequiredColumn (presumably NOT NULL)
///   LastName    VARCHAR(50), declared via Column (presumably nullable)
///   Salary      INTEGER, declared via RequiredColumn
///
/// @param stmt     statement used to run the migration.
/// @param location forwarded to MigrateDirect(); defaults to the caller's source location.
815inline void CreateEmployeesTable(Lightweight::SqlStatement& stmt,
816 std::source_location location = std::source_location::current())
817{
818 stmt.MigrateDirect(
// NOTE(review): the opening of the migration callback, which provides `migration` (presumably
// the migration query builder), is not visible in this listing.
820 migration.CreateTable("Employees")
821 .PrimaryKeyWithAutoIncrement("EmployeeID")
822 .RequiredColumn("FirstName", Lightweight::SqlColumnTypeDefinitions::Varchar { 50 })
823 .Column("LastName", Lightweight::SqlColumnTypeDefinitions::Varchar { 50 })
824 .RequiredColumn("Salary", Lightweight::SqlColumnTypeDefinitions::Integer {});
825 },
826 location);
827}
828
/// Creates the "LargeTable" test table with 26 columns named "A" through "Z", each declared as
/// VARCHAR(50). The loop assumes 'A'..'Z' are contiguous character codes, as in ASCII.
///
/// @param stmt statement used to run the migration.
829inline void CreateLargeTable(Lightweight::SqlStatement& stmt)
830{
// NOTE(review): the call opening the migration callback that provides `migration` (presumably
// stmt.MigrateDirect(...)) is not visible in this listing.
832 auto table = migration.CreateTable("LargeTable");
833 for (char c = 'A'; c <= 'Z'; ++c)
834 {
835 table.Column(std::string(1, c), Lightweight::SqlColumnTypeDefinitions::Varchar { 50 });
836 }
837 });
838}
839
840inline void FillEmployeesTable(Lightweight::SqlStatement& stmt)
841{
842 stmt.Prepare(stmt.Query("Employees")
843 .Insert()
844 .Set("FirstName", Lightweight::SqlWildcard)
845 .Set("LastName", Lightweight::SqlWildcard)
846 .Set("Salary", Lightweight::SqlWildcard));
847 stmt.Execute("Alice", "Smith", 50'000);
848 stmt.Execute("Bob", "Johnson", 60'000);
849 stmt.Execute("Charlie", "Brown", 70'000);
850}
851
/// Builds a string of `size` characters cycling through 'A'..'Z' ("ABC...XYZABC...").
///
/// @tparam T character type of the resulting std::basic_string (char by default).
/// @param size number of characters to generate.
/// @return the generated std::basic_string<T>.
template <typename T = char>
inline auto MakeLargeText(size_t size)
{
    auto text = std::basic_string<T>(size, T {});
    // size_t counter: the previous int counter overflowed (undefined behavior) beyond INT_MAX.
    for (size_t i = 0; i < size; ++i)
        text[i] = static_cast<T>('A' + (i % 26));
    return text;
}
859
/// Returns true when running inside GitHub Actions, i.e. when the environment variable
/// GITHUB_ACTIONS is set to exactly "true".
inline bool IsGithubActions()
{
#if defined(_WIN32) || defined(_WIN64)
    char envBuffer[32] {};
    size_t requiredCount = 0;
    // getenv_s() also succeeds (returns 0, empty buffer) when the variable is unset, so the value
    // itself must decide the result. Comparing the equality result against 0 would invert it.
    return getenv_s(&requiredCount, envBuffer, sizeof(envBuffer), "GITHUB_ACTIONS") == 0
           && std::string_view(envBuffer) == "true";
#else
    // Query the environment once and reuse the pointer.
    char const* const value = std::getenv("GITHUB_ACTIONS");
    return value != nullptr && std::string_view(value) == "true";
#endif
}
872
namespace std
{

/// Streams an optional: the contained value if there is one, otherwise the text "nullopt".
template <typename T>
ostream& operator<<(ostream& os, optional<T> const& opt)
{
    if (!opt)
        return os << "nullopt";
    return os << *opt;
}

} // namespace std
886
/// Streams a nullable Field as `Field<T> { <value>, modified|not modified }`, or as `NULL` when
/// the field holds no value.
template <typename T, auto P1, auto P2>
std::ostream& operator<<(std::ostream& os, Lightweight::Field<std::optional<T>, P1, P2> const& field)
{
    if (!field.Value())
        return os << "NULL";

    char const* const modificationState = field.IsModified() ? "modified" : "not modified";
    return os << std::format("Field<{}> {{ {}, {} }}", Reflection::TypeNameOf<T>, *field.Value(), modificationState);
}
898
/// Streams a Field as `Field<T> { value: <value>; modified|not modified }`. The value is written
/// with its own stream operator.
template <typename T, auto P1, auto P2>
std::ostream& operator<<(std::ostream& os, Lightweight::Field<T, P1, P2> const& field)
{
    os << std::format("Field<{}> {{ ", Reflection::TypeNameOf<T>);
    os << "value: " << field.Value() << "; ";
    os << (field.IsModified() ? "modified" : "not modified");
    return os << " }";
}
Represents a connection to a SQL database.
SqlServerType ServerType() const noexcept
Retrieves the type of the server.
static LIGHTWEIGHT_API void SetDefaultConnectionString(SqlConnectionString const &connectionString) noexcept
SQLHDBC NativeHandle() const noexcept
Retrieves the native handle.
static LIGHTWEIGHT_API void SetPostConnectedHook(std::function< void(SqlConnection &)> hook)
Sets a callback to be called after each connection being established.
LIGHTWEIGHT_API SqlCreateTableQueryBuilder & Column(SqlColumnDeclaration column)
Adds a new column to the table.
LIGHTWEIGHT_FORCE_INLINE constexpr decltype(auto) data(this auto &&self) noexcept
Retrieves the pointer to the string data.
LIGHTWEIGHT_FORCE_INLINE constexpr std::size_t size() const noexcept
Retrieves the size of the string.
LIGHTWEIGHT_FORCE_INLINE std::size_t size() const noexcept
Retrieves the string's size.
LIGHTWEIGHT_FORCE_INLINE T const * data() const noexcept
Retrieves the string's inner value (as T const*).
LIGHTWEIGHT_FORCE_INLINE constexpr std::size_t size() const noexcept
Returns the size of the string.
static LIGHTWEIGHT_API SqlLogger & TraceLogger()
Retrieves a logger that logs to the trace logger.
static LIGHTWEIGHT_API void SetLogger(SqlLogger &logger)
static LIGHTWEIGHT_API SqlLogger & StandardLogger()
Retrieves a logger that logs to standard output.
Query builder for building SQL migration queries.
Definition Migrate.hpp:430
LIGHTWEIGHT_API SqlCreateTableQueryBuilder CreateTable(std::string_view tableName)
Creates a new table.
LIGHTWEIGHT_API SqlInsertQueryBuilder Insert(std::vector< SqlVariant > *boundInputs=nullptr) noexcept
High level API for (prepared) raw SQL statements.
LIGHTWEIGHT_API void Prepare(std::string_view query) &
LIGHTWEIGHT_API SQLHSTMT NativeHandle() const noexcept
Retrieves the native handle of the statement.
void MigrateDirect(Callable const &callable, std::source_location location=std::source_location::current())
Executes an SQL migration query, as created by the callback.
LIGHTWEIGHT_API SqlConnection & Connection() noexcept
Retrieves the connection associated with this statement.
void Execute(Args const &... args)
Binds the given arguments to the prepared statement and executes it.
bool GetColumn(SQLUSMALLINT column, T *result) const
LIGHTWEIGHT_API SqlQueryBuilder Query(std::string_view const &table={}) const
Creates a new query builder for the given table, compatible with the SQL server being connected.
LIGHTWEIGHT_API void ExecuteDirect(std::string_view const &query, std::source_location location=std::source_location::current())
Executes the given query directly.
LIGHTWEIGHT_API bool FetchRow()
LIGHTWEIGHT_API std::u8string ToUtf8(std::u32string_view u32InputString)
Represents a single column in a table.
Definition Field.hpp:84
constexpr bool IsModified() const noexcept
Checks if the field has been modified.
Definition Field.hpp:311
constexpr T const & Value() const noexcept
Returns the value of the field.
Definition Field.hpp:317
Represents an ODBC connection string.
constexpr LIGHTWEIGHT_FORCE_INLINE native_type value() const noexcept
Returns the current date and time.
LIGHTWEIGHT_FORCE_INLINE constexpr std::chrono::year_month_day value() const noexcept
Returns the current date.
Definition SqlDate.hpp:29
Represents an ODBC SQL error.
Definition SqlError.hpp:33
constexpr LIGHTWEIGHT_FORCE_INLINE auto ToUnscaledValue() const noexcept
Converts the numeric to an unscaled integer value.