Lightweight 0.20251202.0
Loading...
Searching...
No Matches
Utils.hpp
1// SPDX-License-Identifier: Apache-2.0
2
3#pragma once
4
5#if defined(_WIN32) || defined(_WIN64)
6 #include <Windows.h>
7#endif
8
9#include <Lightweight/Lightweight.hpp>
10
11#include <catch2/catch_session.hpp>
12#include <catch2/catch_test_macros.hpp>
13#include <yaml-cpp/yaml.h>
14
15#include <algorithm>
16#include <filesystem>
17#include <chrono>
18#include <format>
19#include <mutex>
20#include <ostream>
21#include <ranges>
22#include <string>
23#include <string_view>
24#include <tuple>
25#include <variant>
26#include <vector>
27
28#if __has_include(<stacktrace>)
29 #include <stacktrace>
30#endif
31
32#include <sql.h>
33#include <sqlext.h>
34#include <sqlspi.h>
35#include <sqltypes.h>
36
37// NOTE: I've done this preprocessor stuff only to have a single test for UTF-16 (UCS-2) regardless of platform.
38using WideChar = std::conditional_t<sizeof(wchar_t) == 2, wchar_t, char16_t>;
39using WideString = std::basic_string<WideChar>;
40using WideStringView = std::basic_string_view<WideChar>;
41
42#if defined(LIGHTWEIGHT_CXX26_REFLECTION)
43 /// @brief marco to define a member to the structure, in case of C++26 reflection this
44 /// will create reflection, in case of C++20 reflection this will create a member pointer
45 #define Member(x) ^^x
46
47#else
48 /// @brief marco to define a member to the structure, in case of C++26 reflection this
49 /// will create reflection, in case of C++20 reflection this will create a member pointer
50 #define Member(x) &x
51#endif
52
53#if !defined(_WIN32)
54 #define WTEXT(x) (u##x)
55#else
56 #define WTEXT(x) (L##x)
57#endif
58
59#define UNSUPPORTED_DATABASE(stmt, dbType) \
60 if ((stmt).Connection().ServerType() == (dbType)) \
61 { \
62 WARN(std::format("TODO({}): This database is currently unsupported on this test.", dbType)); \
63 return; \
64 }
65
66template <>
67struct std::formatter<std::u8string>: std::formatter<std::string>
68{
69 auto format(std::u8string const& value, std::format_context& ctx) const -> std::format_context::iterator
70 {
71 return std::formatter<std::string>::format(
72 std::format("{}", (char const*) value.c_str()), // NOLINT(readability-redundant-string-cstr)
73 ctx);
74 }
75};
76
77namespace std
78{
79
80// Add support for std::basic_string<WideChar> and std::basic_string_view<WideChar> to std::ostream,
81// so that we can get them pretty-printed in REQUIRE() and CHECK() macros.
82
83template <typename WideStringT>
84 requires(Lightweight::detail::OneOf<WideStringT,
85 std::wstring,
86 std::wstring_view,
87 std::u16string,
88 std::u16string_view,
89 std::u32string,
90 std::u32string_view>)
91ostream& operator<<(ostream& os, WideStringT const& str)
92{
93 auto constexpr BitsPerChar = sizeof(typename WideStringT::value_type) * 8;
94 auto const u8String = Lightweight::ToUtf8(str);
95 return os << "UTF-" << BitsPerChar << '{' << "length: " << str.size() << ", characters: " << '"'
96 << string_view((char const*) u8String.data(), u8String.size()) << '"' << '}';
97}
98
99inline ostream& operator<<(ostream& os, Lightweight::SqlGuid const& guid)
100{
101 return os << format("SqlGuid({})", guid);
102}
103
104inline std::string EscapedBinaryText(std::string_view binary)
105{
106 std::string hexEncodedString;
107 for (auto const& b: binary)
108 {
109 if (std::isprint(b))
110 hexEncodedString += static_cast<char>(b);
111 else
112 hexEncodedString += std::format("\\x{:02x}", static_cast<unsigned char>(b));
113 }
114 return hexEncodedString;
115}
116
117template <size_t N>
118inline ostream& operator<<(ostream& os, Lightweight::SqlDynamicBinary<N> const& binary)
119{
120 auto const hexEncodedString = EscapedBinaryText(std::string_view((char const*) binary.data(), binary.size()));
121 return os << std::format("SqlDynamicBinary<{}>(length: {}, characters: {})", N, binary.size(), hexEncodedString);
122}
123
124} // namespace std
125
126namespace std
127{
128template <std::size_t Precision, std::size_t Scale>
129std::ostream& operator<<(std::ostream& os, Lightweight::SqlNumeric<Precision, Scale> const& value)
130{
131 return os << std::format("SqlNumeric<{}, {}>({}, {}, {}, {})",
132 Precision,
133 Scale,
134 value.sqlValue.sign,
135 value.sqlValue.precision,
136 value.sqlValue.scale,
137 value.ToUnscaledValue());
138}
139} // namespace std
140
141// Refer to an in-memory SQLite database (and assuming the sqliteodbc driver is installed)
142// See:
143// - https://www.sqlite.org/inmemorydb.html
144// - http://www.ch-werner.de/sqliteodbc/
145// - https://github.com/softace/sqliteodbc
146//
147
148// clang-format off
149auto inline const DefaultTestConnectionString = Lightweight::SqlConnectionString { //NOLINT(bugprone-throwing-static-initialization)
150 // clang-format on
151 .value = std::format("DRIVER={};Database={}",
152#if defined(_WIN32) || defined(_WIN64)
153 "SQLite3 ODBC Driver",
154#else
155 "SQLite3",
156#endif
157 "test.db"),
158};
159
160class TestSuiteSqlLogger: public Lightweight::SqlLogger::Null
161{
162 private:
163 mutable std::mutex m_mutex;
164 std::string m_lastPreparedQuery;
165
166 void WriteRawInfo(std::string_view message);
167
168 template <typename... Args>
169 void WriteInfo(std::format_string<Args...> const& fmt, Args&&... args)
170 {
171 auto message = std::format(fmt, std::forward<Args>(args)...);
172 message = std::format("[{}] {}", "Lightweight", message);
173 WriteRawInfo(message);
174 }
175
176 template <typename... Args>
177 void WriteWarning(std::format_string<Args...> const& fmt, Args&&... args)
178 {
179 WARN(std::format(fmt, std::forward<Args>(args)...));
180 }
181
182 public:
183 static TestSuiteSqlLogger& GetLogger() noexcept
184 {
185 static TestSuiteSqlLogger theLogger;
186 return theLogger;
187 }
188
189 void OnError(Lightweight::SqlError error, std::source_location sourceLocation) override
190 {
191 WriteWarning("SQL Error: {}", error);
192 WriteDetails(sourceLocation);
193 }
194
195 void OnError(Lightweight::SqlErrorInfo const& errorInfo, std::source_location sourceLocation) override
196 {
197 WriteWarning("SQL Error: {}", errorInfo);
198 WriteDetails(sourceLocation);
199 }
200
201 void OnWarning(std::string_view const& message) override
202 {
203 WriteWarning("{}", message);
204 WriteDetails(std::source_location::current());
205 }
206
207 void OnExecuteDirect(std::string_view const& query) override
208 {
209 WriteInfo("ExecuteDirect: {}", query);
210 }
211
212 void OnPrepare(std::string_view const& query) override
213 {
214 std::scoped_lock lock(m_mutex);
215 m_lastPreparedQuery = query;
216 }
217
218 void OnExecute(std::string_view const& query) override
219 {
220 WriteInfo("Execute: {}", query);
221 }
222
223 void OnExecuteBatch() override
224 {
225 std::scoped_lock lock(m_mutex);
226 WriteInfo("ExecuteBatch: {}", m_lastPreparedQuery);
227 }
228
229 void OnFetchRow() override
230 {
231 WriteInfo("Fetched row");
232 }
233
234 void OnFetchEnd() override
235 {
236 // WriteInfo("Fetch end");
237 }
238
239 private:
240 void WriteDetails(std::source_location sourceLocation)
241 {
242 std::scoped_lock lock(m_mutex);
243 WriteInfo(" Source: {}:{}", sourceLocation.file_name(), sourceLocation.line());
244 if (!m_lastPreparedQuery.empty())
245 WriteInfo(" Query: {}", m_lastPreparedQuery);
246 WriteInfo(" Stack trace:");
247
248#if __has_include(<stacktrace>)
249 auto stackTrace = std::stacktrace::current(1, 25);
250 for (auto const& entry: stackTrace)
251 WriteInfo(" {}", entry);
252#endif
253 }
254};
255
256// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
257class ScopedSqlNullLogger: public Lightweight::SqlLogger::Null
258{
259 private:
260 SqlLogger& m_previousLogger = SqlLogger::GetLogger();
261
262 public:
263 ScopedSqlNullLogger()
264 {
265 SqlLogger::SetLogger(*this);
266 }
267
268 ~ScopedSqlNullLogger() override
269 {
270 SqlLogger::SetLogger(m_previousLogger);
271 }
272};
273
274template <typename Getter, typename Callable>
275constexpr void FixedPointIterate(Getter const& getter, Callable const& callable)
276{
277 auto a = getter();
278 for (;;)
279 {
280 callable(a);
281 auto b = getter();
282 if (a == b)
283 break;
284 a = std::move(b);
285 }
286}
287
288// Searches upward from current working directory for .test-env.yml
289// Stops at directory containing .git (project root boundary) or filesystem root
290inline std::optional<std::filesystem::path> FindTestEnvFile()
291{
292 auto currentDir = std::filesystem::current_path();
293
294 while (true)
295 {
296 auto testEnvPath = currentDir / ".test-env.yml";
297 if (std::filesystem::exists(testEnvPath))
298 return testEnvPath;
299
300 // Stop at project root (directory containing .git) or filesystem root
301 auto gitPath = currentDir / ".git";
302 if (std::filesystem::exists(gitPath))
303 return std::nullopt; // Reached project root, file not found
304
305 auto parentDir = currentDir.parent_path();
306 if (parentDir == currentDir)
307 return std::nullopt; // Reached filesystem root
308
309 currentDir = parentDir;
310 }
311}
312
313// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
314class SqlTestFixture
315{
316 public:
317 static inline std::string testDatabaseName = "LightweightTest"; // NOLINT(bugprone-throwing-static-initialization)
318 static inline bool odbcTrace = false;
319 static inline std::atomic<bool> running = false;
320
321 using MainProgramArgs = std::tuple<int, char**>;
322
323 // NOLINTNEXTLINE(readability-function-cognitive-complexity)
324 static std::variant<MainProgramArgs, int> Initialize(int argc, char** argv)
325 {
326 Lightweight::SqlLogger::SetLogger(TestSuiteSqlLogger::GetLogger());
327
328 using namespace std::string_view_literals;
329 std::optional<std::string> testEnvName;
330 int i = 1;
331 for (; i < argc; ++i)
332 {
333 if (argv[i] == "--trace-sql"sv)
335 else if (argv[i] == "--trace-odbc"sv)
336 odbcTrace = true;
337 else if (std::string_view(argv[i]).starts_with("--test-env="))
338 testEnvName = std::string_view(argv[i]).substr(11);
339 else if (argv[i] == "--help"sv || argv[i] == "-h"sv)
340 {
341 std::println("{} [--test-env=NAME] [--trace-sql] [--trace-odbc] [[--] [Catch2 flags ...]]", argv[0]);
342 std::println("");
343 std::println("Options:");
344 std::println(" --test-env=NAME Use connection string from .test-env.yml (e.g., pgsql, mssql, sqlite)");
345 std::println(" --trace-sql Enable SQL tracing");
346 std::println(" --trace-odbc Enable ODBC tracing");
347 return { EXIT_SUCCESS };
348 }
349 else if (argv[i] == "--"sv)
350 {
351 ++i;
352 break;
353 }
354 else
355 break;
356 }
357
358 if (i < argc)
359 argv[i - 1] = argv[0];
360
361 // Connection string priority:
362 // 1. --test-env=NAME (must find valid entry, error if not found)
363 // 2. Environment variable ODBC_CONNECTION_STRING
364 // 3. Default SQLite3
365
366 if (testEnvName)
367 {
368 auto configPath = FindTestEnvFile();
369 if (!configPath)
370 {
371 std::println(stderr,
372 "Error: .test-env.yml not found (searched from '{}' to project root)",
373 std::filesystem::current_path().string());
374 return { EXIT_FAILURE };
375 }
376
377 try
378 {
379 YAML::Node config = YAML::LoadFile(configPath->string());
380 auto connectionStrings = config["ODBC_CONNECTION_STRING"];
381 if (!connectionStrings || !connectionStrings[*testEnvName])
382 {
383 std::println(stderr, "Error: Key '{}' not found in ODBC_CONNECTION_STRING", *testEnvName);
384 if (connectionStrings && connectionStrings.IsMap())
385 {
386 std::print(stderr, "Available environments:");
387 for (auto const& entry: connectionStrings)
388 std::print(stderr, " {}", entry.first.as<std::string>());
389 std::println(stderr, "");
390 }
391 return { EXIT_FAILURE };
392 }
393 auto connStr = connectionStrings[*testEnvName].as<std::string>();
394 if (connStr.empty())
395 {
396 std::println(stderr, "Error: Connection string for '{}' is empty", *testEnvName);
397 return { EXIT_FAILURE };
398 }
399 std::println("Using test environment '{}' from: {}", *testEnvName, configPath->string());
400 std::println("Using ODBC connection string: '{}'", Lightweight::SqlConnectionString::SanitizePwd(connStr));
402 }
403 catch (YAML::Exception const& e)
404 {
405 std::println(stderr, "Error parsing {}: {}", configPath->string(), e.what());
406 return { EXIT_FAILURE };
407 }
408 }
409 else
410 {
411#if defined(_MSC_VER)
412 char* envBuffer = nullptr;
413 size_t envBufferLen = 0;
414 _dupenv_s(&envBuffer, &envBufferLen, "ODBC_CONNECTION_STRING");
415 if (auto const* s = envBuffer; s && *s)
416#else
417 if (auto const* s = std::getenv("ODBC_CONNECTION_STRING"); s && *s)
418#endif
419
420 {
421 std::println("Using ODBC connection string: '{}'", Lightweight::SqlConnectionString::SanitizePwd(s));
423 }
424 else
425 {
426 // Use an in-memory SQLite3 database by default (for testing purposes)
427 std::println("Using default ODBC connection string: '{}'", DefaultTestConnectionString.value);
429 }
430 }
431
432 Lightweight::SqlConnection::SetPostConnectedHook(&SqlTestFixture::PostConnectedHook);
433
434 auto sqlConnection = Lightweight::SqlConnection();
435 if (!sqlConnection.IsAlive())
436 {
437 std::println("Failed to connect to the database: {}", sqlConnection.LastError());
438 std::abort();
439 }
440
441 std::println("Running test cases against: {} ({}) (identified as: {})",
442 sqlConnection.ServerName(),
443 sqlConnection.ServerVersion(),
444 sqlConnection.ServerType());
445
446 return MainProgramArgs { argc - (i - 1), argv + (i - 1) };
447 }
448
449 static void PostConnectedHook(Lightweight::SqlConnection& connection)
450 {
451 if (odbcTrace)
452 {
453 auto const traceFile = []() -> std::string_view {
454#if !defined(_WIN32) && !defined(_WIN64)
455 return "/dev/stdout";
456#else
457 return "CONOUT$";
458#endif
459 }();
460
461 SQLHDBC handle = connection.NativeHandle();
462 SQLSetConnectAttrA(handle, SQL_ATTR_TRACEFILE, (SQLPOINTER) traceFile.data(), SQL_NTS);
463 SQLSetConnectAttrA(handle, SQL_ATTR_TRACE, (SQLPOINTER) SQL_OPT_TRACE_ON, SQL_IS_UINTEGER);
464 }
465
466 using Lightweight::SqlServerType;
467 switch (connection.ServerType())
468 {
469 case SqlServerType::SQLITE: {
470 auto stmt = Lightweight::SqlStatement { connection };
471 // Enable foreign key constraints for SQLite
472 stmt.ExecuteDirect("PRAGMA foreign_keys = ON");
473 break;
474 }
475 case SqlServerType::MICROSOFT_SQL:
476 case SqlServerType::POSTGRESQL:
477 case SqlServerType::MYSQL:
478 case SqlServerType::UNKNOWN:
479 break;
480 }
481 }
482
483 SqlTestFixture()
484 {
485 running = true;
486 auto stmt = Lightweight::SqlStatement();
487 REQUIRE(stmt.IsAlive());
488
489 char dbName[100]; // Buffer to store the database name
490 SQLSMALLINT dbNameLen {};
491 SQLGetInfo(stmt.Connection().NativeHandle(), SQL_DATABASE_NAME, dbName, sizeof(dbName), &dbNameLen);
492 if (dbNameLen > 0)
493 testDatabaseName = dbName;
494
495 DropAllTablesInDatabase(stmt);
496 }
497
498 virtual ~SqlTestFixture()
499 {
500 running = false;
502 }
503
504 static std::string ToString(std::vector<std::string> const& values, std::string_view separator)
505 {
506 auto result = std::string {};
507 for (auto const& value: values)
508 {
509 if (!result.empty())
510 result += separator;
511 result += value;
512 }
513 return result;
514 }
515
516 static void DropTableRecursively(Lightweight::SqlStatement& stmt,
517 Lightweight::SqlSchema::FullyQualifiedTableName const& table)
518 {
519 auto const dependantTables = Lightweight::SqlSchema::AllForeignKeysTo(stmt, table);
520 for (auto const& dependantTable: dependantTables)
521 DropTableRecursively(stmt, dependantTable.foreignKey.table);
522 stmt.ExecuteDirect(std::format("DROP TABLE IF EXISTS \"{}\"", table.table));
523 }
524
525 static void DropTableIfExists(Lightweight::SqlConnection& conn, std::string const& tableName)
526 {
527 Lightweight::SqlStatement stmt { conn };
528 try
529 {
530 stmt.ExecuteDirect(std::format("DROP TABLE IF EXISTS {}", tableName));
531 }
532 catch (...)
533 {
534 ; // ignore
535 }
536 }
537
538 static void DropAllTablesInDatabase(Lightweight::SqlStatement& stmt)
539 {
540 using Lightweight::SqlServerType;
541 switch (stmt.Connection().ServerType())
542 {
543 case SqlServerType::MICROSOFT_SQL:
544 case SqlServerType::MYSQL:
545 stmt.ExecuteDirect(std::format("USE \"{}\"", testDatabaseName));
546 [[fallthrough]];
547 case SqlServerType::SQLITE:
548 case SqlServerType::UNKNOWN: {
549 auto const tableNames = GetAllTableNames(stmt);
550 for (auto const& tableName: tableNames)
551 {
552 if (tableName == "sqlite_sequence")
553 continue;
554
555 DropTableRecursively(stmt,
556 Lightweight::SqlSchema::FullyQualifiedTableName {
557 .catalog = {},
558 .schema = {},
559 .table = tableName,
560 });
561 }
562 break;
563 }
564 case SqlServerType::POSTGRESQL:
565 if (m_createdTables.empty())
566 m_createdTables = GetAllTableNames(stmt);
567 for (auto& createdTable: std::views::reverse(m_createdTables))
568 stmt.ExecuteDirect(std::format("DROP TABLE IF EXISTS \"{}\" CASCADE", createdTable));
569 break;
570 }
571 m_createdTables.clear();
572 }
573
574 static std::string GetDefaultSchemaName(Lightweight::SqlConnection const& connection)
575 {
576 using namespace std::string_literals;
577 switch (connection.ServerType())
578 {
579 case Lightweight::SqlServerType::MICROSOFT_SQL:
580 return "dbo"s;
581 default:
582 return ""s;
583 }
584 }
585
586 private:
587 static std::vector<std::string> GetAllTableNames(Lightweight::SqlStatement& stmt)
588 {
589 using namespace std::string_literals;
590 auto result = std::vector<std::string>();
591 auto const schemaName = GetDefaultSchemaName(stmt.Connection());
592 auto const sqlResult = SQLTables(stmt.NativeHandle(),
593 (SQLCHAR*) testDatabaseName.data(),
594 (SQLSMALLINT) testDatabaseName.size(),
595 (SQLCHAR*) schemaName.data(),
596 (SQLSMALLINT) schemaName.size(),
597 nullptr,
598 0,
599 (SQLCHAR*) "TABLE",
600 SQL_NTS);
601 if (SQL_SUCCEEDED(sqlResult))
602 {
603 while (stmt.FetchRow())
604 {
605 result.emplace_back(stmt.GetColumn<std::string>(3)); // table name
606 }
607 }
608 return result;
609 }
610
611 static inline std::vector<std::string> m_createdTables;
612};
613
614// {{{ ostream support for Lightweight, for debugging purposes
615inline std::ostream& operator<<(std::ostream& os, Lightweight::SqlText const& value)
616{
617 return os << std::format("SqlText({})", value.value);
618}
619
620inline std::ostream& operator<<(std::ostream& os, Lightweight::SqlDate const& date)
621{
622 auto const ymd = date.value();
623 return os << std::format("SqlDate {{ {}-{}-{} }}", ymd.year(), ymd.month(), ymd.day());
624}
625
626inline std::ostream& operator<<(std::ostream& os, Lightweight::SqlTime const& time)
627{
628 auto const value = time.value();
629 return os << std::format("SqlTime {{ {:02}:{:02}:{:02}.{:06} }}",
630 value.hours().count(),
631 value.minutes().count(),
632 value.seconds().count(),
633 value.subseconds().count());
634}
635
636inline std::ostream& operator<<(std::ostream& os, Lightweight::SqlDateTime const& datetime)
637{
638 auto const value = datetime.value();
639 auto const totalDays = std::chrono::floor<std::chrono::days>(value);
640 auto const ymd = std::chrono::year_month_day { totalDays };
641 auto const hms =
642 std::chrono::hh_mm_ss<std::chrono::nanoseconds> { std::chrono::floor<std::chrono::nanoseconds>(value - totalDays) };
643 return os << std::format("SqlDateTime {{ {:04}-{:02}-{:02} {:02}:{:02}:{:02}.{:09} }}",
644 (int) ymd.year(),
645 (unsigned) ymd.month(),
646 (unsigned) ymd.day(),
647 hms.hours().count(),
648 hms.minutes().count(),
649 hms.seconds().count(),
650 hms.subseconds().count());
651}
652
653template <std::size_t N, typename T, Lightweight::SqlFixedStringMode Mode>
654inline std::ostream& operator<<(std::ostream& os, Lightweight::SqlFixedString<N, T, Mode> const& value)
655{
656 if constexpr (Mode == Lightweight::SqlFixedStringMode::FIXED_SIZE)
657 return os << std::format("SqlFixedString<{}> {{ size: {}, data: '{}' }}", N, value.size(), value.data());
658 else if constexpr (Mode == Lightweight::SqlFixedStringMode::FIXED_SIZE_RIGHT_TRIMMED)
659 return os << std::format("SqlTrimmedFixedString<{}> {{ '{}' }}", N, value.data());
660 else if constexpr (Mode == Lightweight::SqlFixedStringMode::VARIABLE_SIZE)
661 {
662 if constexpr (std::same_as<T, char>)
663 return os << std::format("SqlVariableString<{}> {{ size: {}, '{}' }}", N, value.size(), value.data());
664 else
665 {
666 auto u8String = ToUtf8(std::basic_string_view<T>(value.data(), value.size()));
667 return os << std::format("SqlVariableString<{}, {}> {{ size: {}, '{}' }}",
668 N,
669 Reflection::TypeNameOf<T>,
670 value.size(),
671 (char const*) u8String.c_str()); // NOLINT(readability-redundant-string-cstr)
672 }
673 }
674 else
675 return os << std::format("SqlFixedString<{}> {{ size: {}, data: '{}' }}", N, value.size(), value.data());
676}
677
678template <std::size_t N, typename T>
679inline std::ostream& operator<<(std::ostream& os, Lightweight::SqlDynamicString<N, T> const& value)
680{
681 if constexpr (std::same_as<T, char>)
682 return os << std::format("SqlDynamicString<{}> {{ size: {}, '{}' }}", N, value.size(), value.data());
683 else
684 {
685 auto u8String = ToUtf8(std::basic_string_view<T>(value.data(), value.size()));
686 return os << std::format("SqlDynamicString<{}, {}> {{ size: {}, '{}' }}",
687 N,
688 Reflection::TypeNameOf<T>,
689 value.size(),
690 (char const*) u8String.c_str()); // NOLINT(readability-redundant-string-cstr)
691 }
692}
693
694[[nodiscard]] inline std::string NormalizeText(std::string_view const& text)
695{
696 auto result = std::string(text);
697
698 // Remove any newlines and reduce all whitespace to a single space
699 result.erase(std::unique(result.begin(), // NOLINT(modernize-use-ranges)
700 result.end(),
701 [](char a, char b) { return std::isspace(a) && std::isspace(b); }),
702 result.end());
703
704 // trim lading and trailing whitespace
705 while (!result.empty() && std::isspace(result.front()))
706 result.erase(result.begin());
707
708 while (!result.empty() && std::isspace(result.back()))
709 result.pop_back();
710
711 return result;
712}
713
714[[nodiscard]] inline std::string NormalizeText(std::vector<std::string> const& texts)
715{
716 auto result = std::string {};
717 for (auto const& text: texts)
718 {
719 if (!result.empty())
720 result += '\n';
721 result += NormalizeText(text);
722 }
723 return result;
724}
725
726// }}}
727
728/// Enables foreign keys for SQLite (no-op for other DBMS).
729inline void EnableForeignKeysIfNeeded(Lightweight::SqlConnection& conn)
730{
731 if (conn.ServerType() == Lightweight::SqlServerType::SQLITE)
732 {
733 Lightweight::SqlStatement stmt { conn };
734 stmt.ExecuteDirect("PRAGMA foreign_keys = ON");
735 }
736}
737
738/// Disables foreign keys for SQLite (no-op for other DBMS).
739inline void DisableForeignKeysIfNeeded(Lightweight::SqlConnection& conn)
740{
741 if (conn.ServerType() == Lightweight::SqlServerType::SQLITE)
742 {
743 Lightweight::SqlStatement stmt { conn };
744 stmt.ExecuteDirect("PRAGMA foreign_keys = OFF");
745 }
746}
747
748/// Wraps a function with IDENTITY_INSERT ON/OFF for MS SQL.
749/// For other DBMS, the function is called directly.
750template <typename Func>
751inline void WithIdentityInsert(Lightweight::SqlStatement& stmt, std::string_view tableName, Func&& func)
752{
753 if (stmt.Connection().ServerType() == Lightweight::SqlServerType::MICROSOFT_SQL)
754 {
755 stmt.ExecuteDirect(std::format("SET IDENTITY_INSERT \"{}\" ON", tableName));
756 try
757 {
758 std::forward<Func>(func)();
759 stmt.ExecuteDirect(std::format("SET IDENTITY_INSERT \"{}\" OFF", tableName));
760 }
761 catch (...)
762 {
763 stmt.ExecuteDirect(std::format("SET IDENTITY_INSERT \"{}\" OFF", tableName));
764 throw;
765 }
766 }
767 else
768 {
769 std::forward<Func>(func)();
770 }
771}
772
773inline void CreateEmployeesTable(Lightweight::SqlStatement& stmt,
774 std::source_location location = std::source_location::current())
775{
776 stmt.MigrateDirect(
778 migration.CreateTable("Employees")
779 .PrimaryKeyWithAutoIncrement("EmployeeID")
780 .RequiredColumn("FirstName", Lightweight::SqlColumnTypeDefinitions::Varchar { 50 })
781 .Column("LastName", Lightweight::SqlColumnTypeDefinitions::Varchar { 50 })
782 .RequiredColumn("Salary", Lightweight::SqlColumnTypeDefinitions::Integer {});
783 },
784 location);
785}
786
787inline void CreateLargeTable(Lightweight::SqlStatement& stmt)
788{
790 auto table = migration.CreateTable("LargeTable");
791 for (char c = 'A'; c <= 'Z'; ++c)
792 {
793 table.Column(std::string(1, c), Lightweight::SqlColumnTypeDefinitions::Varchar { 50 });
794 }
795 });
796}
797
798inline void FillEmployeesTable(Lightweight::SqlStatement& stmt)
799{
800 stmt.Prepare(stmt.Query("Employees")
801 .Insert()
802 .Set("FirstName", Lightweight::SqlWildcard)
803 .Set("LastName", Lightweight::SqlWildcard)
804 .Set("Salary", Lightweight::SqlWildcard));
805 stmt.Execute("Alice", "Smith", 50'000);
806 stmt.Execute("Bob", "Johnson", 60'000);
807 stmt.Execute("Charlie", "Brown", 70'000);
808}
809
810template <typename T = char>
811inline auto MakeLargeText(size_t size)
812{
813 auto text = std::basic_string<T>(size, {});
814 std::ranges::generate(text, [i = 0]() mutable { return static_cast<T>('A' + (i++ % 26)); });
815 return text;
816}
817
818inline bool IsGithubActions()
819{
820#if defined(_WIN32) || defined(_WIN64)
821 char envBuffer[32] {};
822 size_t requiredCount = 0;
823 return getenv_s(&requiredCount, envBuffer, sizeof(envBuffer), "GITHUB_ACTIONS") == 0
824 && std::string_view(envBuffer) == "true" == 0;
825#else
826 return std::getenv("GITHUB_ACTIONS") != nullptr
827 && std::string_view(std::getenv("GITHUB_ACTIONS")) == "true"; // NOLINT(clang-analyzer-core.NonNullParamChecker)
828#endif
829}
830
831namespace std
832{
833
834template <typename T>
835ostream& operator<<(ostream& os, optional<T> const& opt)
836{
837 if (opt.has_value())
838 return os << *opt;
839 else
840 return os << "nullopt";
841}
842
843} // namespace std
844
845template <typename T, auto P1, auto P2>
846std::ostream& operator<<(std::ostream& os, Lightweight::Field<std::optional<T>, P1, P2> const& field)
847{
848 if (field.Value())
849 return os << std::format("Field<{}> {{ {}, {} }}",
850 Reflection::TypeNameOf<T>,
851 *field.Value(),
852 field.IsModified() ? "modified" : "not modified");
853 else
854 return os << "NULL";
855}
856
857template <typename T, auto P1, auto P2>
858std::ostream& operator<<(std::ostream& os, Lightweight::Field<T, P1, P2> const& field)
859{
860 return os << std::format("Field<{}> {{ ", Reflection::TypeNameOf<T>) << "value: " << field.Value() << "; "
861 << (field.IsModified() ? "modified" : "not modified") << " }";
862}
Represents a connection to a SQL database.
SqlServerType ServerType() const noexcept
Retrieves the type of the server.
static LIGHTWEIGHT_API void SetDefaultConnectionString(SqlConnectionString const &connectionString) noexcept
SQLHDBC NativeHandle() const noexcept
Retrieves the native handle.
static LIGHTWEIGHT_API void SetPostConnectedHook(std::function< void(SqlConnection &)> hook)
Sets a callback to be called after each connection being established.
LIGHTWEIGHT_API SqlCreateTableQueryBuilder & Column(SqlColumnDeclaration column)
Adds a new column to the table.
LIGHTWEIGHT_FORCE_INLINE constexpr decltype(auto) data(this auto &&self) noexcept
Retrieves the pointer to the string data.
LIGHTWEIGHT_FORCE_INLINE constexpr std::size_t size() const noexcept
Retrieves the size of the string.
LIGHTWEIGHT_FORCE_INLINE std::size_t size() const noexcept
Retrieves the string's size.
LIGHTWEIGHT_FORCE_INLINE T const * data() const noexcept
Retrieves the string's inner value (as T const*).
LIGHTWEIGHT_FORCE_INLINE constexpr std::size_t size() const noexcept
Returns the size of the string.
static LIGHTWEIGHT_API SqlLogger & TraceLogger()
Retrieves a logger that logs to the trace logger.
static LIGHTWEIGHT_API void SetLogger(SqlLogger &logger)
static LIGHTWEIGHT_API SqlLogger & StandardLogger()
Retrieves a logger that logs to standard output.
Query builder for building SQL migration queries.
Definition Migrate.hpp:331
LIGHTWEIGHT_API SqlCreateTableQueryBuilder CreateTable(std::string_view tableName)
Creates a new table.
LIGHTWEIGHT_API SqlInsertQueryBuilder Insert(std::vector< SqlVariant > *boundInputs=nullptr) noexcept
High level API for (prepared) raw SQL statements.
LIGHTWEIGHT_API void Prepare(std::string_view query) &
LIGHTWEIGHT_API SQLHSTMT NativeHandle() const noexcept
Retrieves the native handle of the statement.
void MigrateDirect(Callable const &callable, std::source_location location=std::source_location::current())
Executes an SQL migration query, as created b the callback.
LIGHTWEIGHT_API SqlConnection & Connection() noexcept
Retrieves the connection associated with this statement.
void Execute(Args const &... args)
Binds the given arguments to the prepared statement and executes it.
bool GetColumn(SQLUSMALLINT column, T *result) const
LIGHTWEIGHT_API SqlQueryBuilder Query(std::string_view const &table={}) const
Creates a new query builder for the given table, compatible with the SQL server being connected.
LIGHTWEIGHT_API void ExecuteDirect(std::string_view const &query, std::source_location location=std::source_location::current())
Executes the given query directly.
LIGHTWEIGHT_API bool FetchRow()
LIGHTWEIGHT_API std::u8string ToUtf8(std::u32string_view u32InputString)
Represents a single column in a table.
Definition Field.hpp:84
constexpr bool IsModified() const noexcept
Checks if the field has been modified.
Definition Field.hpp:311
constexpr T const & Value() const noexcept
Returns the value of the field.
Definition Field.hpp:317
Represents an ODBC connection string.
constexpr LIGHTWEIGHT_FORCE_INLINE native_type value() const noexcept
Returns the current date and time.
LIGHTWEIGHT_FORCE_INLINE constexpr std::chrono::year_month_day value() const noexcept
Returns the current date.
Definition SqlDate.hpp:29
Represents an ODBC SQL error.
Definition SqlError.hpp:33
constexpr LIGHTWEIGHT_FORCE_INLINE auto ToUnscaledValue() const noexcept
Converts the numeric to an unscaled integer value.