Lightweight 0.20251202.0
Loading...
Searching...
No Matches
Utils.hpp
1// SPDX-License-Identifier: Apache-2.0
2
3#pragma once
4
5#if defined(_WIN32) || defined(_WIN64)
6 #include <Windows.h>
7#endif
8
9#include <Lightweight/Lightweight.hpp>
10
11#include <catch2/catch_session.hpp>
12#include <catch2/catch_test_macros.hpp>
13#include <yaml-cpp/yaml.h>
14
15#include <algorithm>
16#include <filesystem>
17#include <chrono>
18#include <format>
19#include <mutex>
20#include <ostream>
21#include <ranges>
22#include <string>
23#include <string_view>
24#include <tuple>
25#include <variant>
26#include <vector>
27
28#if __has_include(<stacktrace>)
29 #include <stacktrace>
30#endif
31
32#include <sql.h>
33#include <sqlext.h>
34#include <sqlspi.h>
35#include <sqltypes.h>
36
// NOTE: WideChar is chosen at compile time so that a single UTF-16 (UCS-2) test works on every platform:
// it is wchar_t where wchar_t is 2 bytes wide, and char16_t everywhere else.
using WideChar = std::conditional<sizeof(wchar_t) == 2, wchar_t, char16_t>::type;
using WideString = std::basic_string<WideChar>;
using WideStringView = std::basic_string_view<WideChar>;
41
/// @brief Macro to refer to a member of a structure.
/// With C++26 reflection (LIGHTWEIGHT_CXX26_REFLECTION defined) it yields a reflection (`^^x`);
/// otherwise it yields a pointer to member (`&x`).
#if !defined(LIGHTWEIGHT_CXX26_REFLECTION)
    #define Member(x) &x
#else
    #define Member(x) ^^x
#endif

/// @brief Wide string literal helper: `L"..."` on Windows, `u"..."` (char16_t) elsewhere,
/// intended to match the WideChar code unit type.
#if defined(_WIN32)
    #define WTEXT(x) (L##x)
#else
    #define WTEXT(x) (u##x)
#endif
58
/// Skips the rest of the current (void-returning) test function, emitting a Catch2 WARN, when `stmt`
/// is connected to a server of type `dbType`.
/// Wrapped in do { ... } while (false) so the expansion is a single statement and cannot capture a
/// following `else` when used inside an unbraced if/else. Use it with a trailing semicolon.
#define UNSUPPORTED_DATABASE(stmt, dbType)                                                                 \
    do                                                                                                     \
    {                                                                                                      \
        if ((stmt).Connection().ServerType() == (dbType))                                                  \
        {                                                                                                  \
            WARN(std::format("TODO({}): This database is currently unsupported on this test.", dbType));  \
            return;                                                                                        \
        }                                                                                                  \
    } while (false)
65
/// std::format support for std::u8string: the UTF-8 code units are reinterpreted as `char` and formatted
/// by the std::string formatter, so all std::string format specs apply.
template <>
struct std::formatter<std::u8string>: std::formatter<std::string>
{
    auto format(std::u8string const& value, std::format_context& ctx) const -> std::format_context::iterator
    {
        // Build the std::string from (pointer, length) rather than c_str(). This keeps embedded NUL
        // code units and creates only one temporary string instead of two.
        return std::formatter<std::string>::format(
            std::string(reinterpret_cast<char const*>(value.data()), value.size()), ctx);
    }
};
76
namespace std
{

// Stream operators for wchar_t, UTF-16 and UTF-32 strings and string views. They convert the text to
// UTF-8 so it is pretty-printed in REQUIRE() and CHECK() failure messages.

/// Prints a wide/UTF-16/UTF-32 string as `UTF-<bits>{length: <code units>, characters: "<UTF-8 text>"}`.
template <typename WideStringT>
    requires(Lightweight::detail::OneOf<WideStringT,
                                        std::wstring,
                                        std::wstring_view,
                                        std::u16string,
                                        std::u16string_view,
                                        std::u32string,
                                        std::u32string_view>)
ostream& operator<<(ostream& os, WideStringT const& str)
{
    auto constexpr BitsPerChar = sizeof(typename WideStringT::value_type) * 8;
    auto const u8String = Lightweight::ToUtf8(str);
    return os << "UTF-" << BitsPerChar << '{' << "length: " << str.size() << ", characters: " << '"'
              << string_view((char const*) u8String.data(), u8String.size()) << '"' << '}';
}

/// Prints a SqlGuid as `SqlGuid(<formatted guid>)`.
inline ostream& operator<<(ostream& os, Lightweight::SqlGuid const& guid)
{
    return os << format("SqlGuid({})", guid);
}

/// Returns `binary` with printable bytes copied verbatim and every other byte written as a `\xNN` hex escape.
inline std::string EscapedBinaryText(std::string_view binary)
{
    std::string escaped;
    escaped.reserve(binary.size());
    for (char const ch: binary)
    {
        // std::isprint() is only defined for values representable as unsigned char (or EOF).
        // Passing a plain, possibly signed, char >= 0x80 would be undefined behavior.
        auto const byte = static_cast<unsigned char>(ch);
        if (std::isprint(byte))
            escaped += ch;
        else
            escaped += std::format("\\x{:02x}", byte);
    }
    return escaped;
}

/// Prints a SqlDynamicBinary<N> with its length and its contents escaped via EscapedBinaryText().
template <size_t N>
inline ostream& operator<<(ostream& os, Lightweight::SqlDynamicBinary<N> const& binary)
{
    auto const hexEncodedString = EscapedBinaryText(std::string_view((char const*) binary.data(), binary.size()));
    return os << std::format("SqlDynamicBinary<{}>(length: {}, characters: {})", N, binary.size(), hexEncodedString);
}

} // namespace std
125
namespace std
{
/// Prints a SqlNumeric as `SqlNumeric<Precision, Scale>(sign, precision, scale, unscaled value)`.
/// Sign, precision and scale come from `value.sqlValue`; the last field is ToUnscaledValue().
template <std::size_t Precision, std::size_t Scale>
std::ostream& operator<<(std::ostream& os, Lightweight::SqlNumeric<Precision, Scale> const& value)
{
    auto const& raw = value.sqlValue;
    auto const text = std::format("SqlNumeric<{}, {}>({}, {}, {}, {})",
                                  Precision,
                                  Scale,
                                  raw.sign,
                                  raw.precision,
                                  raw.scale,
                                  value.ToUnscaledValue());
    os << text;
    return os;
}
} // namespace std
140
// Default connection string, used when neither --test-env nor ODBC_CONNECTION_STRING is given.
// It refers to a SQLite database stored in the file "test.db". This is NOT an in-memory database.
// It assumes the sqliteodbc driver is installed.
// See:
// - http://www.ch-werner.de/sqliteodbc/
// - https://github.com/softace/sqliteodbc
// - https://www.sqlite.org/inmemorydb.html (if an in-memory database is wanted instead)
//

// clang-format off
auto inline const DefaultTestConnectionString = Lightweight::SqlConnectionString { //NOLINT(bugprone-throwing-static-initialization)
    // clang-format on
    .value = std::format("DRIVER={};Database={}",
#if defined(_WIN32) || defined(_WIN64)
                         "SQLite3 ODBC Driver",
#else
                         "SQLite3",
#endif
                         "test.db"),
};
159
/// Logger installed for the whole test run.
///
/// Informational events (direct/prepared executions, batch executions, fetched rows) are formatted,
/// prefixed with "[Lightweight]" and passed to WriteRawInfo(), which is declared here and defined elsewhere.
/// Errors and warnings are reported through Catch2's WARN(). They are followed by details: the reported
/// source location, the most recently prepared query (if any) and, when <stacktrace> is available, up to
/// 25 stack frames.
class TestSuiteSqlLogger: public Lightweight::SqlLogger::Null
{
  private:
    mutable std::mutex m_mutex;      // guards m_lastPreparedQuery
    std::string m_lastPreparedQuery; // last query seen by OnPrepare()

    void WriteRawInfo(std::string_view message);

    /// Formats the message, prefixes it with "[Lightweight] " and forwards it to WriteRawInfo().
    template <typename... Args>
    void WriteInfo(std::format_string<Args...> const& fmt, Args&&... args)
    {
        auto const body = std::format(fmt, std::forward<Args>(args)...);
        WriteRawInfo(std::format("[{}] {}", "Lightweight", body));
    }

    /// Formats the message and reports it as a Catch2 warning.
    template <typename... Args>
    void WriteWarning(std::format_string<Args...> const& fmt, Args&&... args)
    {
        auto const text = std::format(fmt, std::forward<Args>(args)...);
        WARN(text);
    }

  public:
    /// Returns the process-wide instance, a function-local static created on first use.
    static TestSuiteSqlLogger& GetLogger() noexcept
    {
        static TestSuiteSqlLogger instance;
        return instance;
    }

    void OnError(Lightweight::SqlError error, std::source_location sourceLocation) override
    {
        WriteWarning("SQL Error: {}", error);
        WriteDetails(sourceLocation);
    }

    void OnError(Lightweight::SqlErrorInfo const& errorInfo, std::source_location sourceLocation) override
    {
        WriteWarning("SQL Error: {}", errorInfo);
        WriteDetails(sourceLocation);
    }

    void OnWarning(std::string_view const& message) override
    {
        WriteWarning("{}", message);
        // No caller location is available here, so the logger's own location is reported.
        WriteDetails(std::source_location::current());
    }

    void OnExecuteDirect(std::string_view const& query) override
    {
        WriteInfo("ExecuteDirect: {}", query);
    }

    void OnPrepare(std::string_view const& query) override
    {
        std::scoped_lock const guard { m_mutex };
        m_lastPreparedQuery.assign(query);
    }

    void OnExecute(std::string_view const& query) override
    {
        WriteInfo("Execute: {}", query);
    }

    void OnExecuteBatch() override
    {
        std::scoped_lock const guard { m_mutex };
        WriteInfo("ExecuteBatch: {}", m_lastPreparedQuery);
    }

    void OnFetchRow() override
    {
        WriteInfo("Fetched row");
    }

    void OnFetchEnd() override
    {
        // Intentionally silent: WriteInfo("Fetch end");
    }

  private:
    /// Writes the reported source location, the last prepared query (if any) and a stack trace via WriteInfo().
    void WriteDetails(std::source_location sourceLocation)
    {
        std::scoped_lock const guard { m_mutex };
        WriteInfo(" Source: {}:{}", sourceLocation.file_name(), sourceLocation.line());
        if (!m_lastPreparedQuery.empty())
            WriteInfo(" Query: {}", m_lastPreparedQuery);
        WriteInfo(" Stack trace:");

#if __has_include(<stacktrace>)
        // Skip the innermost entry and keep at most 25 frames.
        auto const frames = std::stacktrace::current(1, 25);
        for (std::size_t index = 0; index < frames.size(); ++index)
            WriteInfo(" [{:>2}] {}", index, frames[index]);
#endif
    }
};
255
256// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
257class ScopedSqlNullLogger: public Lightweight::SqlLogger::Null
258{
259 private:
260 SqlLogger& m_previousLogger = SqlLogger::GetLogger();
261
262 public:
263 ScopedSqlNullLogger()
264 {
265 SqlLogger::SetLogger(*this);
266 }
267
268 ~ScopedSqlNullLogger() override
269 {
270 SqlLogger::SetLogger(m_previousLogger);
271 }
272};
273
/// Calls `callable(value)` with the value returned by `getter()` and repeats until two consecutive
/// `getter()` results compare equal. `callable` is always invoked at least once.
template <typename Getter, typename Callable>
constexpr void FixedPointIterate(Getter const& getter, Callable const& callable)
{
    auto current = getter();
    while (true)
    {
        callable(current);
        auto next = getter();
        if (current == next)
            return;
        current = std::move(next);
    }
}
287
/// Looks for a `.test-env.yml` file, starting in the current working directory and walking up the
/// directory tree. The search gives up at the first directory that contains a `.git` entry (the project
/// root) or at the filesystem root.
///
/// @return Path of the first `.test-env.yml` found, or std::nullopt if none was found.
inline std::optional<std::filesystem::path> FindTestEnvFile()
{
    namespace fs = std::filesystem;
    for (auto dir = fs::current_path();; dir = dir.parent_path())
    {
        if (auto candidate = dir / ".test-env.yml"; fs::exists(candidate))
            return candidate;

        // Stop at the project root (contains .git) or when there is no parent left.
        if (fs::exists(dir / ".git") || dir.parent_path() == dir)
            return std::nullopt;
    }
}
312
313// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
314class SqlTestFixture
315{
316 public:
/// Name of the database under test; replaced by the driver-reported name when a fixture is constructed.
static inline std::string testDatabaseName { "LightweightTest" }; // NOLINT(bugprone-throwing-static-initialization)
/// Set when --trace-odbc was given on the command line.
static inline bool odbcTrace { false };
/// True while a fixture instance is alive (set by the constructor, cleared by the destructor).
static inline std::atomic<bool> running { false };

/// (argc, argv) to hand on to Catch2 once the fixture's own options have been consumed.
using MainProgramArgs = std::tuple<int, char**>;
322
/// Parses the fixture's own command-line options, selects the ODBC connection string, installs the
/// post-connect hook and verifies that a connection can be opened.
///
/// Recognized options (parsing stops at the first unrecognized argument or after "--"):
///   --trace-sql, --trace-odbc (sets odbcTrace), --test-env=NAME, --help / -h.
///
/// @return An int exit code when the program should stop (help printed, or a configuration error),
///         otherwise the (argc, argv) pair to pass on to Catch2.
/// Calls std::abort() if the initial connection is not alive.
323 static std::variant<MainProgramArgs, int> Initialize(int argc, char** argv)
324 {
325 Lightweight::SqlLogger::SetLogger(TestSuiteSqlLogger::GetLogger());
326
327 using namespace std::string_view_literals;
328 std::optional<std::string> testEnvName;
329 int i = 1;
330 for (; i < argc; ++i)
331 {
332 if (argv[i] == "--trace-sql"sv)
334 else if (argv[i] == "--trace-odbc"sv)
335 odbcTrace = true;
336 else if (std::string_view(argv[i]).starts_with("--test-env="))
 // 11 == length of the "--test-env=" prefix
337 testEnvName = std::string_view(argv[i]).substr(11);
338 else if (argv[i] == "--help"sv || argv[i] == "-h"sv)
339 {
340 std::println("{} [--test-env=NAME] [--trace-sql] [--trace-odbc] [[--] [Catch2 flags ...]]", argv[0]);
341 std::println("");
342 std::println("Options:");
343 std::println(" --test-env=NAME Use connection string from .test-env.yml (e.g., pgsql, mssql, sqlite)");
344 std::println(" --trace-sql Enable SQL tracing");
345 std::println(" --trace-odbc Enable ODBC tracing");
346 return { EXIT_SUCCESS };
347 }
348 else if (argv[i] == "--"sv)
349 {
350 ++i;
351 break;
352 }
353 else
354 break;
355 }
356
 // argv[i - 1 ..] is what Catch2 receives: slot i - 1 is reused to carry the program name.
 // NOTE(review): when every argument was consumed (i == argc > 1) this slot is not overwritten,
 // so Catch2 sees the last consumed option as its program name. Confirm whether that matters.
357 if (i < argc)
358 argv[i - 1] = argv[0];
359
360 // Connection string priority:
361 // 1. --test-env=NAME (must find valid entry, error if not found)
362 // 2. Environment variable ODBC_CONNECTION_STRING
363 // 3. Default SQLite3
364
365 if (testEnvName)
366 {
367 auto configPath = FindTestEnvFile();
368 if (!configPath)
369 {
370 std::println(stderr,
371 "Error: .test-env.yml not found (searched from '{}' to project root)",
372 std::filesystem::current_path().string());
373 return { EXIT_FAILURE };
374 }
375
376 try
377 {
 // Expected layout: ODBC_CONNECTION_STRING: { <env name>: <connection string>, ... }
378 YAML::Node config = YAML::LoadFile(configPath->string());
379 auto connectionStrings = config["ODBC_CONNECTION_STRING"];
380 if (!connectionStrings || !connectionStrings[*testEnvName])
381 {
382 std::println(stderr, "Error: Key '{}' not found in ODBC_CONNECTION_STRING", *testEnvName);
383 if (connectionStrings && connectionStrings.IsMap())
384 {
385 std::print(stderr, "Available environments:");
386 for (auto const& entry: connectionStrings)
387 std::print(stderr, " {}", entry.first.as<std::string>());
388 std::println(stderr, "");
389 }
390 return { EXIT_FAILURE };
391 }
392 auto connStr = connectionStrings[*testEnvName].as<std::string>();
393 if (connStr.empty())
394 {
395 std::println(stderr, "Error: Connection string for '{}' is empty", *testEnvName);
396 return { EXIT_FAILURE };
397 }
398 std::println("Using test environment '{}' from: {}", *testEnvName, configPath->string());
 // SanitizePwd() is applied before printing, presumably to mask the password.
399 std::println("Using ODBC connection string: '{}'", Lightweight::SqlConnectionString::SanitizePwd(connStr));
401 }
402 catch (YAML::Exception const& e)
403 {
404 std::println(stderr, "Error parsing {}: {}", configPath->string(), e.what());
405 return { EXIT_FAILURE };
406 }
407 }
408 else
409 {
410#if defined(_MSC_VER)
411 char* envBuffer = nullptr;
412 size_t envBufferLen = 0;
413 _dupenv_s(&envBuffer, &envBufferLen, "ODBC_CONNECTION_STRING");
 // NOTE(review): the buffer allocated by _dupenv_s() is never released with free().
 // This is a one-time leak at startup, but worth fixing.
414 if (auto const* s = envBuffer; s && *s)
415#else
416 if (auto const* s = std::getenv("ODBC_CONNECTION_STRING"); s && *s)
417#endif
418
419 {
420 std::println("Using ODBC connection string: '{}'", Lightweight::SqlConnectionString::SanitizePwd(s));
422 }
423 else
424 {
425 // Use the default SQLite3 connection string (file-based "test.db", see DefaultTestConnectionString)
426 std::println("Using default ODBC connection string: '{}'", DefaultTestConnectionString.value);
428 }
429 }
430
 // Every connection opened from now on runs PostConnectedHook (ODBC tracing, SQLite foreign keys).
431 Lightweight::SqlConnection::SetPostConnectedHook(&SqlTestFixture::PostConnectedHook);
432
433 auto sqlConnection = Lightweight::SqlConnection();
434 if (!sqlConnection.IsAlive())
435 {
436 std::println("Failed to connect to the database: {}", sqlConnection.LastError());
437 std::abort();
438 }
439
440 std::println("Running test cases against: {} ({}) (identified as: {})",
441 sqlConnection.ServerName(),
442 sqlConnection.ServerVersion(),
443 sqlConnection.ServerType());
444
445 return MainProgramArgs { argc - (i - 1), argv + (i - 1) };
446 }
447
448 static void PostConnectedHook(Lightweight::SqlConnection& connection)
449 {
450 if (odbcTrace)
451 {
452 auto const traceFile = []() -> std::string_view {
453#if !defined(_WIN32) && !defined(_WIN64)
454 return "/dev/stdout";
455#else
456 return "CONOUT$";
457#endif
458 }();
459
460 SQLHDBC handle = connection.NativeHandle();
461 SQLSetConnectAttrA(handle, SQL_ATTR_TRACEFILE, (SQLPOINTER) traceFile.data(), SQL_NTS);
462 SQLSetConnectAttrA(handle, SQL_ATTR_TRACE, (SQLPOINTER) SQL_OPT_TRACE_ON, SQL_IS_UINTEGER);
463 }
464
465 using Lightweight::SqlServerType;
466 switch (connection.ServerType())
467 {
468 case SqlServerType::SQLITE: {
469 auto stmt = Lightweight::SqlStatement { connection };
470 // Enable foreign key constraints for SQLite
471 stmt.ExecuteDirect("PRAGMA foreign_keys = ON");
472 break;
473 }
474 case SqlServerType::MICROSOFT_SQL:
475 case SqlServerType::POSTGRESQL:
476 case SqlServerType::MYSQL:
477 case SqlServerType::UNKNOWN:
478 break;
479 }
480 }
481
482 SqlTestFixture()
483 {
484 running = true;
485 auto stmt = Lightweight::SqlStatement();
486 REQUIRE(stmt.IsAlive());
487
488 char dbName[100]; // Buffer to store the database name
489 SQLSMALLINT dbNameLen {};
490 SQLGetInfo(stmt.Connection().NativeHandle(), SQL_DATABASE_NAME, dbName, sizeof(dbName), &dbNameLen);
491 if (dbNameLen > 0)
492 testDatabaseName = dbName;
493
494 DropAllTablesInDatabase(stmt);
495 }
496
/// Per-test teardown: clears the `running` flag set by the constructor.
/// Virtual, presumably because test fixtures derive from SqlTestFixture.
497 virtual ~SqlTestFixture()
498 {
499 running = false;
501 }
502
/// Concatenates `values`, inserting `separator` before an element only once the accumulated text is
/// non-empty. As a result, leading empty elements do not produce separators.
static std::string ToString(std::vector<std::string> const& values, std::string_view separator)
{
    std::string joined;
    for (std::string const& item: values)
    {
        if (!joined.empty())
            joined.append(separator);
        joined.append(item);
    }
    return joined;
}
514
/// Drops `table`. Every table whose foreign keys reference it is dropped first (recursively), so the
/// DROP is not blocked by those references.
///
/// A table that is already being dropped further up the recursion is skipped. This covers a
/// self-referencing foreign key or a reference cycle, so the recursion always terminates.
static void DropTableRecursively(Lightweight::SqlStatement& stmt,
                                 Lightweight::SqlSchema::FullyQualifiedTableName const& table)
{
    auto visiting = std::vector<std::string> {};
    DropTableRecursivelyImpl(stmt, table, visiting);
}

/// Worker for DropTableRecursively(). `visiting` holds the names of the tables on the current recursion path.
static void DropTableRecursivelyImpl(Lightweight::SqlStatement& stmt,
                                     Lightweight::SqlSchema::FullyQualifiedTableName const& table,
                                     std::vector<std::string>& visiting)
{
    visiting.push_back(table.table);

    auto const dependantTables = Lightweight::SqlSchema::AllForeignKeysTo(stmt, table);
    for (auto const& dependantTable: dependantTables)
    {
        auto const& dependant = dependantTable.foreignKey.table;
        if (std::ranges::find(visiting, dependant.table) != visiting.end())
            continue; // self-reference or cycle: already being dropped further up the stack
        DropTableRecursivelyImpl(stmt, dependant, visiting);
    }

    stmt.ExecuteDirect(std::format("DROP TABLE IF EXISTS \"{}\"", table.table));
    visiting.pop_back();
}
523
/// Best-effort `DROP TABLE IF EXISTS <tableName>` on `conn`. Exceptions thrown while executing the
/// statement are deliberately ignored. `tableName` is inserted verbatim (not quoted).
static void DropTableIfExists(Lightweight::SqlConnection& conn, std::string const& tableName)
{
    auto stmt = Lightweight::SqlStatement { conn };
    auto const query = std::format("DROP TABLE IF EXISTS {}", tableName);
    try
    {
        stmt.ExecuteDirect(query);
    }
    catch (...)
    {
        // Intentionally ignored (best effort).
    }
}
536
/// Removes every table from the test database so that each test starts from a clean slate.
///
/// - Microsoft SQL Server / MySQL: switches to `testDatabaseName`, then drops the tables one by one.
/// - SQLite / unknown servers: drops the tables one by one, skipping SQLite's `sqlite_sequence`. Each is
///   dropped via DropTableRecursively() so that referencing tables go first.
/// - PostgreSQL: drops the tables in reverse order with `DROP TABLE IF EXISTS ... CASCADE`.
///
/// `m_createdTables` is cleared afterwards in every case.
static void DropAllTablesInDatabase(Lightweight::SqlStatement& stmt)
{
    auto const dropTablesOneByOne = [&stmt] {
        auto const tableNames = GetAllTableNames(stmt);
        for (auto const& tableName: tableNames)
        {
            if (tableName != "sqlite_sequence")
                DropTableRecursively(stmt,
                                     Lightweight::SqlSchema::FullyQualifiedTableName {
                                         .catalog = {},
                                         .schema = {},
                                         .table = tableName,
                                     });
        }
    };

    using Lightweight::SqlServerType;
    switch (stmt.Connection().ServerType())
    {
        case SqlServerType::MICROSOFT_SQL:
        case SqlServerType::MYSQL:
            // Make the test database the active one before enumerating its tables.
            stmt.ExecuteDirect(std::format("USE \"{}\"", testDatabaseName));
            dropTablesOneByOne();
            break;
        case SqlServerType::SQLITE:
        case SqlServerType::UNKNOWN:
            dropTablesOneByOne();
            break;
        case SqlServerType::POSTGRESQL:
            if (m_createdTables.empty())
                m_createdTables = GetAllTableNames(stmt);
            for (auto it = m_createdTables.rbegin(); it != m_createdTables.rend(); ++it)
                stmt.ExecuteDirect(std::format("DROP TABLE IF EXISTS \"{}\" CASCADE", *it));
            break;
    }
    m_createdTables.clear();
}
572
573 static std::string GetDefaultSchemaName(Lightweight::SqlConnection const& connection)
574 {
575 using namespace std::string_literals;
576 switch (connection.ServerType())
577 {
578 case Lightweight::SqlServerType::MICROSOFT_SQL:
579 return "dbo"s;
580 default:
581 return ""s;
582 }
583 }
584
585 private:
586 static std::vector<std::string> GetAllTableNames(Lightweight::SqlStatement& stmt)
587 {
588 using namespace std::string_literals;
589 auto result = std::vector<std::string>();
590 auto const schemaName = GetDefaultSchemaName(stmt.Connection());
591 auto const sqlResult = SQLTables(stmt.NativeHandle(),
592 (SQLCHAR*) testDatabaseName.data(),
593 (SQLSMALLINT) testDatabaseName.size(),
594 (SQLCHAR*) schemaName.data(),
595 (SQLSMALLINT) schemaName.size(),
596 nullptr,
597 0,
598 (SQLCHAR*) "TABLE",
599 SQL_NTS);
600 if (SQL_SUCCEEDED(sqlResult))
601 {
602 while (stmt.FetchRow())
603 {
604 result.emplace_back(stmt.GetColumn<std::string>(3)); // table name
605 }
606 }
607 return result;
608 }
609
610 static inline std::vector<std::string> m_createdTables;
611};
612
613// {{{ ostream support for Lightweight, for debugging purposes
/// Prints a SqlText as `SqlText(<text>)`.
inline std::ostream& operator<<(std::ostream& os, Lightweight::SqlText const& value)
{
    auto const text = std::format("SqlText({})", value.value);
    os << text;
    return os;
}
618
/// Prints a SqlDate as `SqlDate { YYYY-MM-DD }`, with numeric, zero-padded month and day.
/// This matches the date part of the SqlDateTime printer.
inline std::ostream& operator<<(std::ostream& os, Lightweight::SqlDate const& date)
{
    auto const ymd = date.value();
    // Format the fields as integers. std::format on a std::chrono::month prints the abbreviated month
    // name (e.g. "Jan") rather than its number.
    return os << std::format("SqlDate {{ {:04}-{:02}-{:02} }}",
                             static_cast<int>(ymd.year()),
                             static_cast<unsigned>(ymd.month()),
                             static_cast<unsigned>(ymd.day()));
}
624
/// Prints a SqlTime as `SqlTime { HH:MM:SS.ffffff }`, taking the fields from `time.value()`.
inline std::ostream& operator<<(std::ostream& os, Lightweight::SqlTime const& time)
{
    auto const hms = time.value();
    auto const hours = hms.hours().count();
    auto const minutes = hms.minutes().count();
    auto const seconds = hms.seconds().count();
    auto const fraction = hms.subseconds().count();
    os << std::format("SqlTime {{ {:02}:{:02}:{:02}.{:06} }}", hours, minutes, seconds, fraction);
    return os;
}
634
/// Prints a SqlDateTime as `SqlDateTime { YYYY-MM-DD HH:MM:SS.fffffffff }` (nanosecond fraction).
inline std::ostream& operator<<(std::ostream& os, Lightweight::SqlDateTime const& datetime)
{
    namespace chr = std::chrono;
    auto const timePoint = datetime.value();

    // Split into the calendar day and the time elapsed since that day's midnight.
    auto const midnight = chr::floor<chr::days>(timePoint);
    auto const calendar = chr::year_month_day { midnight };
    auto const clock = chr::hh_mm_ss<chr::nanoseconds> { chr::floor<chr::nanoseconds>(timePoint - midnight) };

    os << std::format("SqlDateTime {{ {:04}-{:02}-{:02} {:02}:{:02}:{:02}.{:09} }}",
                      (int) calendar.year(),
                      (unsigned) calendar.month(),
                      (unsigned) calendar.day(),
                      clock.hours().count(),
                      clock.minutes().count(),
                      clock.seconds().count(),
                      clock.subseconds().count());
    return os;
}
651
/// Prints a SqlFixedString for debugging. The layout depends on the string mode:
/// - FIXED_SIZE_RIGHT_TRIMMED: `SqlTrimmedFixedString<N> { '<data>' }`
/// - VARIABLE_SIZE: `SqlVariableString<N> { size: <n>, '<text>' }`. For non-char code units the text is
///   converted to UTF-8 and the character type is added after N.
/// - FIXED_SIZE and any other mode: `SqlFixedString<N> { size: <n>, data: '<data>' }`
template <std::size_t N, typename T, Lightweight::SqlFixedStringMode Mode>
inline std::ostream& operator<<(std::ostream& os, Lightweight::SqlFixedString<N, T, Mode> const& value)
{
    if constexpr (Mode == Lightweight::SqlFixedStringMode::FIXED_SIZE_RIGHT_TRIMMED)
        return os << std::format("SqlTrimmedFixedString<{}> {{ '{}' }}", N, value.data());
    else if constexpr (Mode == Lightweight::SqlFixedStringMode::VARIABLE_SIZE)
    {
        if constexpr (std::same_as<T, char>)
        {
            // Print exactly size() characters instead of relying on NUL termination.
            return os << std::format("SqlVariableString<{}> {{ size: {}, '{}' }}",
                                     N,
                                     value.size(),
                                     std::string_view(value.data(), value.size()));
        }
        else
        {
            // Qualified call. ADL on std::basic_string_view only searches namespace std, so an
            // unqualified ToUtf8 would depend on some ToUtf8 happening to be visible at global scope.
            auto const u8String = Lightweight::ToUtf8(std::basic_string_view<T>(value.data(), value.size()));
            return os << std::format("SqlVariableString<{}, {}> {{ size: {}, '{}' }}",
                                     N,
                                     Reflection::TypeNameOf<T>,
                                     value.size(),
                                     std::string_view(reinterpret_cast<char const*>(u8String.data()), u8String.size()));
        }
    }
    else
        return os << std::format("SqlFixedString<{}> {{ size: {}, data: '{}' }}", N, value.size(), value.data());
}
676
/// Prints a SqlDynamicString as `SqlDynamicString<N> { size: <n>, '<text>' }`. For non-char code units
/// the text is converted to UTF-8 and the character type is added after N.
template <std::size_t N, typename T>
inline std::ostream& operator<<(std::ostream& os, Lightweight::SqlDynamicString<N, T> const& value)
{
    if constexpr (std::same_as<T, char>)
    {
        // Print exactly size() characters instead of relying on NUL termination.
        return os << std::format(
                   "SqlDynamicString<{}> {{ size: {}, '{}' }}", N, value.size(), std::string_view(value.data(), value.size()));
    }
    else
    {
        // Qualified call. ADL on std::basic_string_view only searches namespace std, so an unqualified
        // ToUtf8 would depend on some ToUtf8 happening to be visible at global scope.
        auto const u8String = Lightweight::ToUtf8(std::basic_string_view<T>(value.data(), value.size()));
        return os << std::format("SqlDynamicString<{}, {}> {{ size: {}, '{}' }}",
                                 N,
                                 Reflection::TypeNameOf<T>,
                                 value.size(),
                                 std::string_view(reinterpret_cast<char const*>(u8String.data()), u8String.size()));
    }
}
692
/// Normalizes whitespace so that texts (e.g. SQL) can be compared regardless of layout.
/// Each run of consecutive whitespace characters is collapsed to its first character, so a run starting
/// with a newline becomes a single newline and a run starting with a space becomes a single space.
/// Leading and trailing whitespace is then removed.
[[nodiscard]] inline std::string NormalizeText(std::string_view const& text)
{
    // std::isspace() is only defined for values representable as unsigned char (or EOF).
    auto const isSpace = [](char c) { return std::isspace(static_cast<unsigned char>(c)) != 0; };

    auto result = std::string(text);

    // Collapse every whitespace run to its first character.
    result.erase(std::unique(result.begin(), // NOLINT(modernize-use-ranges)
                             result.end(),
                             [&isSpace](char a, char b) { return isSpace(a) && isSpace(b); }),
                 result.end());

    // Trim leading and trailing whitespace.
    result.erase(result.begin(), std::find_if_not(result.begin(), result.end(), isSpace));
    while (!result.empty() && isSpace(result.back()))
        result.pop_back();

    return result;
}

/// Normalizes each element of `texts` with NormalizeText(std::string_view) and joins the results with '\n'.
/// A newline is only inserted once the accumulated result is non-empty.
[[nodiscard]] inline std::string NormalizeText(std::vector<std::string> const& texts)
{
    auto result = std::string {};
    for (auto const& text: texts)
    {
        if (!result.empty())
            result += '\n';
        result += NormalizeText(text);
    }
    return result;
}
724
725// }}}
726
727/// Enables foreign keys for SQLite (no-op for other DBMS).
728inline void EnableForeignKeysIfNeeded(Lightweight::SqlConnection& conn)
729{
730 if (conn.ServerType() == Lightweight::SqlServerType::SQLITE)
731 {
732 Lightweight::SqlStatement stmt { conn };
733 stmt.ExecuteDirect("PRAGMA foreign_keys = ON");
734 }
735}
736
737/// Disables foreign keys for SQLite (no-op for other DBMS).
738inline void DisableForeignKeysIfNeeded(Lightweight::SqlConnection& conn)
739{
740 if (conn.ServerType() == Lightweight::SqlServerType::SQLITE)
741 {
742 Lightweight::SqlStatement stmt { conn };
743 stmt.ExecuteDirect("PRAGMA foreign_keys = OFF");
744 }
745}
746
/// Runs `func` with `SET IDENTITY_INSERT "<tableName>" ON` in effect on Microsoft SQL Server, then turns
/// it OFF again. For other DBMS, `func` is simply called.
///
/// If `func` throws, IDENTITY_INSERT is switched off on a best-effort basis and the ORIGINAL exception is
/// rethrown. A failure while switching it off never replaces the caller-visible error.
template <typename Func>
inline void WithIdentityInsert(Lightweight::SqlStatement& stmt, std::string_view tableName, Func&& func)
{
    if (stmt.Connection().ServerType() != Lightweight::SqlServerType::MICROSOFT_SQL)
    {
        std::forward<Func>(func)();
        return;
    }

    auto const setIdentityInsert = [&](std::string_view state) {
        stmt.ExecuteDirect(std::format("SET IDENTITY_INSERT \"{}\" {}", tableName, state));
    };

    setIdentityInsert("ON");
    try
    {
        std::forward<Func>(func)();
    }
    catch (...)
    {
        try
        {
            setIdentityInsert("OFF");
        }
        catch (...) // NOLINT(bugprone-empty-catch)
        {
            // Ignored: the exception thrown by func is the one the caller needs to see.
        }
        throw;
    }
    setIdentityInsert("OFF");
}
771
/// Creates the "Employees" test table through stmt.MigrateDirect(). Columns:
///   EmployeeID (auto-increment primary key), FirstName (required, Varchar 50),
///   LastName (declared via Column(), presumably nullable; Varchar 50), Salary (required, Integer).
/// `location` defaults to the caller's source location and is forwarded to MigrateDirect().
772inline void CreateEmployeesTable(Lightweight::SqlStatement& stmt,
773 std::source_location location = std::source_location::current())
774{
775 stmt.MigrateDirect(
777 migration.CreateTable("Employees")
778 .PrimaryKeyWithAutoIncrement("EmployeeID")
779 .RequiredColumn("FirstName", Lightweight::SqlColumnTypeDefinitions::Varchar { 50 })
780 .Column("LastName", Lightweight::SqlColumnTypeDefinitions::Varchar { 50 })
781 .RequiredColumn("Salary", Lightweight::SqlColumnTypeDefinitions::Integer {});
782 },
783 location);
784}
785
/// Creates the "LargeTable" test table with 26 columns named "A" through "Z", each Varchar 50.
/// The columns are defined through a `migration` builder inside a migration callback.
786inline void CreateLargeTable(Lightweight::SqlStatement& stmt)
787{
789 auto table = migration.CreateTable("LargeTable");
 // One single-letter column per iteration: "A", "B", ..., "Z".
790 for (char c = 'A'; c <= 'Z'; ++c)
791 {
792 table.Column(std::string(1, c), Lightweight::SqlColumnTypeDefinitions::Varchar { 50 });
793 }
794 });
795}
796
797inline void FillEmployeesTable(Lightweight::SqlStatement& stmt)
798{
799 stmt.Prepare(stmt.Query("Employees")
800 .Insert()
801 .Set("FirstName", Lightweight::SqlWildcard)
802 .Set("LastName", Lightweight::SqlWildcard)
803 .Set("Salary", Lightweight::SqlWildcard));
804 stmt.Execute("Alice", "Smith", 50'000);
805 stmt.Execute("Bob", "Johnson", 60'000);
806 stmt.Execute("Charlie", "Brown", 70'000);
807}
808
/// Returns a string of `size` code units of type T that repeats the alphabet "ABC...Z".
template <typename T = char>
inline auto MakeLargeText(size_t size)
{
    auto text = std::basic_string<T>(size, T {});
    for (size_t index = 0; index < text.size(); ++index)
        text[index] = static_cast<T>('A' + static_cast<int>(index % 26));
    return text;
}
816
/// Returns true when running under GitHub Actions, i.e. when the GITHUB_ACTIONS environment variable
/// is set to exactly "true".
inline bool IsGithubActions()
{
#if defined(_WIN32) || defined(_WIN64)
    char envBuffer[32] {};
    size_t requiredCount = 0;
    // Per the getenv_s() documentation, an unset variable also returns 0 and leaves an empty buffer.
    // That case fails the string comparison below.
    return getenv_s(&requiredCount, envBuffer, sizeof(envBuffer), "GITHUB_ACTIONS") == 0
           && std::string_view(envBuffer) == "true";
#else
    char const* const value = std::getenv("GITHUB_ACTIONS");
    return value != nullptr && std::string_view(value) == "true";
#endif
}
829
namespace std
{

/// Prints the contained value of an engaged optional, or the text "nullopt" for an empty one.
template <typename T>
ostream& operator<<(ostream& os, optional<T> const& opt)
{
    if (!opt.has_value())
        return os << "nullopt";
    return os << *opt;
}

} // namespace std
843
/// Prints an optional-valued Field as `Field<T> { <value>, modified|not modified }`, or `NULL` when it
/// holds no value.
template <typename T, auto P1, auto P2>
std::ostream& operator<<(std::ostream& os, Lightweight::Field<std::optional<T>, P1, P2> const& field)
{
    auto const& maybeValue = field.Value();
    if (!maybeValue)
        return os << "NULL";

    char const* const state = field.IsModified() ? "modified" : "not modified";
    return os << std::format("Field<{}> {{ {}, {} }}", Reflection::TypeNameOf<T>, *maybeValue, state);
}
855
/// Prints a Field as `Field<T> { value: <value>; modified|not modified }`. The value is streamed with
/// its own operator<<.
template <typename T, auto P1, auto P2>
std::ostream& operator<<(std::ostream& os, Lightweight::Field<T, P1, P2> const& field)
{
    os << std::format("Field<{}> {{ ", Reflection::TypeNameOf<T>);
    os << "value: " << field.Value() << "; ";
    os << (field.IsModified() ? "modified" : "not modified");
    return os << " }";
}
Represents a connection to a SQL database.
SqlServerType ServerType() const noexcept
Retrieves the type of the server.
static LIGHTWEIGHT_API void SetDefaultConnectionString(SqlConnectionString const &connectionString) noexcept
SQLHDBC NativeHandle() const noexcept
Retrieves the native handle.
static LIGHTWEIGHT_API void SetPostConnectedHook(std::function< void(SqlConnection &)> hook)
Sets a callback to be called after each connection being established.
LIGHTWEIGHT_API SqlCreateTableQueryBuilder & Column(SqlColumnDeclaration column)
Adds a new column to the table.
LIGHTWEIGHT_FORCE_INLINE constexpr decltype(auto) data(this auto &&self) noexcept
Retrieves the pointer to the string data.
LIGHTWEIGHT_FORCE_INLINE constexpr std::size_t size() const noexcept
Retrieves the size of the string.
LIGHTWEIGHT_FORCE_INLINE std::size_t size() const noexcept
Retrieves the string's size.
LIGHTWEIGHT_FORCE_INLINE T const * data() const noexcept
Retrieves the string's inner value (as T const*).
LIGHTWEIGHT_FORCE_INLINE constexpr std::size_t size() const noexcept
Returns the size of the string.
static LIGHTWEIGHT_API SqlLogger & TraceLogger()
Retrieves a logger that logs to the trace logger.
static LIGHTWEIGHT_API void SetLogger(SqlLogger &logger)
static LIGHTWEIGHT_API SqlLogger & StandardLogger()
Retrieves a logger that logs to standard output.
Query builder for building SQL migration queries.
Definition Migrate.hpp:218
LIGHTWEIGHT_API SqlCreateTableQueryBuilder CreateTable(std::string_view tableName)
Creates a new table.
LIGHTWEIGHT_API SqlInsertQueryBuilder Insert(std::vector< SqlVariant > *boundInputs=nullptr) noexcept
High level API for (prepared) raw SQL statements.
LIGHTWEIGHT_API void Prepare(std::string_view query) &
LIGHTWEIGHT_API SQLHSTMT NativeHandle() const noexcept
Retrieves the native handle of the statement.
void MigrateDirect(Callable const &callable, std::source_location location=std::source_location::current())
Executes an SQL migration query, as created by the callback.
LIGHTWEIGHT_API SqlConnection & Connection() noexcept
Retrieves the connection associated with this statement.
void Execute(Args const &... args)
Binds the given arguments to the prepared statement and executes it.
bool GetColumn(SQLUSMALLINT column, T *result) const
LIGHTWEIGHT_API SqlQueryBuilder Query(std::string_view const &table={}) const
Creates a new query builder for the given table, compatible with the SQL server being connected.
LIGHTWEIGHT_API void ExecuteDirect(std::string_view const &query, std::source_location location=std::source_location::current())
Executes the given query directly.
LIGHTWEIGHT_API bool FetchRow()
LIGHTWEIGHT_API std::u8string ToUtf8(std::u32string_view u32InputString)
Represents a single column in a table.
Definition Field.hpp:84
constexpr bool IsModified() const noexcept
Checks if the field has been modified.
Definition Field.hpp:311
constexpr T const & Value() const noexcept
Returns the value of the field.
Definition Field.hpp:317
Represents an ODBC connection string.
constexpr LIGHTWEIGHT_FORCE_INLINE native_type value() const noexcept
Returns the current date and time.
LIGHTWEIGHT_FORCE_INLINE constexpr std::chrono::year_month_day value() const noexcept
Returns the current date.
Definition SqlDate.hpp:29
Represents an ODBC SQL error.
Definition SqlError.hpp:33
constexpr LIGHTWEIGHT_FORCE_INLINE auto ToUnscaledValue() const noexcept
Converts the numeric to an unscaled integer value.