From 9996bd3f89490e4e298608874633b5f41b954558 Mon Sep 17 00:00:00 2001 From: Nils Mueller Date: Thu, 26 Feb 2026 21:26:35 +0200 Subject: [PATCH] add transparent RDS IAM backend auth --- Cargo.lock | 629 ++++++++++++++++++++++- README.md | 19 + example.pgdog.toml | 4 + example.users.toml | 6 + pgdog-config/src/core.rs | 84 ++- pgdog-config/src/lib.rs | 2 +- pgdog-config/src/users.rs | 56 ++ pgdog/Cargo.toml | 2 + pgdog/src/backend/auth/mod.rs | 1 + pgdog/src/backend/auth/rds_iam.rs | 170 ++++++ pgdog/src/backend/error.rs | 3 + pgdog/src/backend/mod.rs | 1 + pgdog/src/backend/pool/address.rs | 84 ++- pgdog/src/backend/schema/sync/pg_dump.rs | 84 ++- pgdog/src/backend/server.rs | 84 ++- pgdog/src/config/mod.rs | 4 +- 16 files changed, 1173 insertions(+), 60 deletions(-) create mode 100644 pgdog/src/backend/auth/mod.rs create mode 100644 pgdog/src/backend/auth/rds_iam.rs diff --git a/Cargo.lock b/Cargo.lock index f2b9d965c..ec18a9c5a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -167,6 +167,48 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "aws-config" +version = "1.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c478f5b10ce55c9a33f87ca3404ca92768b144fc1bfdede7c0121214a8283a25" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-sdk-sso", + "aws-sdk-ssooidc", + "aws-sdk-sts", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "hex", + "http 1.3.1", + "ring 0.17.14", + "time", + "tokio", + "tracing", + "url", + "zeroize", +] + +[[package]] +name = "aws-credential-types" +version = "1.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cd362783681b15d136480ad555a099e82ecd8e2d10a841e14dfd0078d67fee3" +dependencies = [ + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "zeroize", +] + [[package]] name = "aws-lc-rs" version = "1.13.1" @@ -190,6 +232,323 @@ dependencies = [ "fs_extra", ] +[[package]] +name = "aws-runtime" +version = "1.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c034a1bc1d70e16e7f4e4caf7e9f7693e4c9c24cd91cf17c2a0b21abaebc7c8b" +dependencies = [ + "aws-credential-types", + "aws-sigv4", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "http-body 0.4.6", + "percent-encoding", + "pin-project-lite", + "tracing", + "uuid", +] + +[[package]] +name = "aws-sdk-rds" +version = "1.104.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b01fb9c0ce98222a36bb15679a67e303ec0dd608c7da91af6f2acaa6c177cee" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-sigv4", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-query", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-smithy-xml", + "aws-types", + "fastrand", + "http 0.2.12", + "regex-lite", + "tracing", + "url", +] + +[[package]] +name = "aws-sdk-sso" +version = "1.82.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b069e4973dc25875bbd54e4c6658bdb4086a846ee9ed50f328d4d4c33ebf9857" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sdk-ssooidc" +version = "1.83.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b49e8fe57ff100a2f717abfa65bdd94e39702fa5ab3f60cddc6ac7784010c68" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sdk-sts" +version = "1.84.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91abcdbfb48c38a0419eb75e0eac772a4783a96750392680e4f3c25a8a0535b9" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-query", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-smithy-xml", + "aws-types", + "fastrand", + "http 0.2.12", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sigv4" +version = "1.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69e523e1c4e8e7e8ff219d732988e22bfeae8a1cafdbe6d9eca1546fa080be7c" +dependencies = [ + "aws-credential-types", + "aws-smithy-http", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "form_urlencoded", + "hex", + "hmac", + "http 0.2.12", + "http 1.3.1", + "percent-encoding", + "sha2", + "time", + "tracing", +] + +[[package]] +name = "aws-smithy-async" +version = "1.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ee19095c7c4dda59f1697d028ce704c24b2d33c6718790c7f1d5a3015b4107c" +dependencies = [ + "futures-util", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "aws-smithy-http" +version = "0.62.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "826141069295752372f8203c17f28e30c464d22899a43a0c9fd9c458d469c88b" +dependencies = [ + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "bytes-utils", + "futures-core", + "futures-util", + "http 0.2.12", + "http 1.3.1", + "http-body 0.4.6", + "percent-encoding", + "pin-project-lite", + "pin-utils", + "tracing", +] + +[[package]] +name = "aws-smithy-http-client" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f108f1ca850f3feef3009bdcc977be201bca9a91058864d9de0684e64514bee0" +dependencies = [ + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "h2 0.3.27", + "h2 0.4.10", + "http 0.2.12", + "http 1.3.1", + "http-body 0.4.6", + "hyper 0.14.32", + "hyper 1.6.0", + "hyper-rustls 0.24.2", + "hyper-rustls 0.27.7", + "hyper-util", + "pin-project-lite", + "rustls 0.21.12", + "rustls 0.23.27", + "rustls-native-certs 0.8.1", + "rustls-pki-types", + "tokio", + "tower", + "tracing", +] + +[[package]] +name = "aws-smithy-json" +version = "0.61.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49fa1213db31ac95288d981476f78d05d9cbb0353d22cdf3472cc05bb02f6551" +dependencies = [ + "aws-smithy-types", +] + +[[package]] +name = "aws-smithy-observability" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17f616c3f2260612fe44cede278bafa18e73e6479c4e393e2c4518cf2a9a228a" +dependencies = [ + "aws-smithy-runtime-api", +] + +[[package]] +name = "aws-smithy-query" +version = "0.60.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae5d689cf437eae90460e944a58b5668530d433b4ff85789e69d2f2a556e057d" +dependencies = [ + "aws-smithy-types", + "urlencoding", +] + +[[package]] +name = "aws-smithy-runtime" +version = "1.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e107ce0783019dbff59b3a244aa0c114e4a8c9d93498af9162608cd5474e796" +dependencies = [ + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-http-client", + "aws-smithy-observability", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "fastrand", + "http 0.2.12", + "http 1.3.1", + "http-body 0.4.6", + "http-body 1.0.1", + "pin-project-lite", + "pin-utils", + "tokio", + "tracing", +] + +[[package]] +name = "aws-smithy-runtime-api" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efce7aaaf59ad53c5412f14fc19b2d5c6ab2c3ec688d272fd31f76ec12f44fb0" +dependencies = [ + "aws-smithy-async", + "aws-smithy-types", + "bytes", + "http 0.2.12", + "http 1.3.1", + "pin-project-lite", + "tokio", + "tracing", + "zeroize", +] + +[[package]] +name = "aws-smithy-types" +version = "1.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65f172bcb02424eb94425db8aed1b6d583b5104d4d5ddddf22402c661a320048" +dependencies = [ + "base64-simd", + "bytes", + "bytes-utils", + "futures-core", + "http 0.2.12", + "http 1.3.1", + "http-body 0.4.6", + "http-body 1.0.1", + "http-body-util", + "itoa", + "num-integer", + "pin-project-lite", + "pin-utils", + "ryu", + "serde", + "time", + "tokio", + "tokio-util", +] + +[[package]] +name = "aws-smithy-xml" +version = "0.60.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b53543b4b86ed43f051644f704a98c7291b3618b67adf057ee77a366fa52fcaa" +dependencies = [ + "xmlparser", +] + +[[package]] +name = "aws-types" +version = "1.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d980627d2dd7bfc32a3c025685a033eeab8d365cc840c631ef59d1b8f428164" +dependencies = [ + "aws-credential-types", + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "rustc_version", + "tracing", +] + [[package]] name = "backtrace" version = "0.3.75" @@ -211,12 +570,28 @@ version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "base64" version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64-simd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "339abbe78e73178762e23bea9dfd08e697eb3f3301cd4be981c0f78ba5859195" +dependencies = [ + "outref", + "vsimd", +] + [[package]] name = "base64ct" version = "1.7.3" @@ -422,6 +797,16 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +[[package]] +name = "bytes-utils" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35" +dependencies = [ + "bytes", + "either", +] + [[package]] name = "castaway" version = "0.2.3" @@ -1280,6 +1665,25 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" +[[package]] +name = "h2" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0beca50380b1fc32983fc1cb4587bfa4bb9e78fc259aad4a0032d2080309222d" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http 0.2.12", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "h2" version = "0.4.10" @@ -1291,7 +1695,7 @@ dependencies = [ "fnv", "futures-core", "futures-sink", - "http", + "http 1.3.1", "indexmap", "slab", "tokio", @@ -1419,6 +1823,17 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "http" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http" version = "1.3.1" @@ -1430,6 +1845,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http-body" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" +dependencies = [ + "bytes", + "http 0.2.12", + "pin-project-lite", +] + [[package]] name = "http-body" version = "1.0.1" @@ -1437,7 +1863,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http", + "http 1.3.1", ] [[package]] @@ -1448,8 +1874,8 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "pin-project-lite", ] @@ -1465,6 +1891,30 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "hyper" +version = "0.14.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "h2 0.3.27", + "http 0.2.12", + "http-body 0.4.6", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", + "want", +] + [[package]] name = "hyper" version = "1.6.0" @@ -1474,9 +1924,9 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "h2", - "http", - "http-body", + "h2 0.4.10", + "http 1.3.1", + "http-body 1.0.1", "httparse", "httpdate", "itoa", @@ -1486,19 +1936,36 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" +dependencies = [ + "futures-util", + "http 0.2.12", + "hyper 0.14.32", + "log", + "rustls 0.21.12", + "rustls-native-certs 0.6.3", + "tokio", + "tokio-rustls 0.24.1", +] + [[package]] name = "hyper-rustls" version = "0.27.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ - "http", - "hyper", + "http 1.3.1", + "hyper 1.6.0", "hyper-util", - "rustls", + "rustls 0.23.27", + "rustls-native-certs 0.8.1", "rustls-pki-types", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.2", "tower-service", ] @@ -1510,7 +1977,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" dependencies = [ "bytes", "http-body-util", - "hyper", + "hyper 1.6.0", "hyper-util", "native-tls", "tokio", @@ -1527,9 +1994,9 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http", - "http-body", - "hyper", + "http 1.3.1", + "http-body 1.0.1", + "hyper 1.6.0", "libc", "pin-project-lite", "socket2", @@ -2244,6 +2711,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "outref" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" + [[package]] name = "overload" version = "0.1.1" @@ -2384,6 +2857,8 @@ version = "0.1.30" dependencies = [ "arc-swap", "async-trait", + "aws-config", + "aws-sdk-rds", "base64 0.22.1", "bytes", "cc", @@ -2396,7 +2871,7 @@ dependencies = [ "futures", "hickory-resolver", "http-body-util", - "hyper", + "hyper 1.6.0", "hyper-util", "indexmap", "lazy_static", @@ -2416,7 +2891,7 @@ dependencies = [ "regex", "rmp-serde", "rust_decimal", - "rustls-native-certs", + "rustls-native-certs 0.8.1", "rustls-pki-types", "scram", "semver", @@ -2429,7 +2904,7 @@ dependencies = [ "thiserror 2.0.12", "tikv-jemallocator", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.2", "tokio-util", "toml", "tracing", @@ -2493,7 +2968,7 @@ dependencies = [ "pgdog-vector", "rust_decimal", "serde", - "thiserror 1.0.69", + "thiserror 2.0.12", "uuid", ] @@ -2992,6 +3467,12 @@ dependencies = [ "regex-syntax 0.8.5", ] +[[package]] +name = "regex-lite" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab834c73d247e67f4fae452806d17d3c7501756d98c8808d7c9c7aa7d18f973" + [[package]] name = "regex-syntax" version = "0.6.29" @@ -3024,12 +3505,12 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2", - "http", - "http-body", + "h2 0.4.10", + "http 1.3.1", + "http-body 1.0.1", "http-body-util", - "hyper", - "hyper-rustls", + "hyper 1.6.0", + "hyper-rustls 0.27.7", "hyper-tls", "hyper-util", "ipnet", @@ -3040,7 +3521,7 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls-pemfile", + "rustls-pemfile 2.2.0", "serde", "serde_json", "serde_urlencoded", @@ -3179,7 +3660,7 @@ dependencies = [ "sqlx", "tokio", "tokio-postgres", - "tokio-rustls", + "tokio-rustls 0.26.2", "uuid", ] @@ -3253,6 +3734,18 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "rustls" +version = "0.21.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" +dependencies = [ + "log", + "ring 0.17.14", + "rustls-webpki 0.101.7", + "sct", +] + [[package]] name = "rustls" version = "0.23.27" @@ -3263,11 +3756,23 @@ dependencies = [ "log", "once_cell", "rustls-pki-types", - "rustls-webpki", + "rustls-webpki 0.103.3", "subtle", "zeroize", ] +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile 1.0.4", + "schannel", + "security-framework 2.11.1", +] + [[package]] name = "rustls-native-certs" version = "0.8.1" @@ -3280,6 +3785,15 @@ dependencies = [ "security-framework 3.2.0", ] +[[package]] +name = "rustls-pemfile" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" +dependencies = [ + "base64 0.21.7", +] + [[package]] name = "rustls-pemfile" version = "2.2.0" @@ -3298,6 +3812,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring 0.17.14", + "untrusted 0.9.0", +] + [[package]] name = "rustls-webpki" version = "0.103.3" @@ -3362,6 +3886,16 @@ dependencies = [ "ring 0.16.20", ] +[[package]] +name = "sct" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" +dependencies = [ + "ring 0.17.14", + "untrusted 0.9.0", +] + [[package]] name = "sdd" version = "3.0.8" @@ -4131,6 +4665,7 @@ dependencies = [ "powerfmt", "serde", "time-core", + "time-macros", ] [[package]] @@ -4139,6 +4674,16 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" +[[package]] +name = "time-macros" +version = "0.2.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49" +dependencies = [ + "num-conv", + "time-core", +] + [[package]] name = "tinystr" version = "0.8.1" @@ -4229,13 +4774,23 @@ dependencies = [ "whoami", ] +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls 0.21.12", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" dependencies = [ - "rustls", + "rustls 0.23.27", "tokio", ] @@ -4486,6 +5041,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -4527,6 +5088,12 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "vsimd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" + [[package]] name = "vtparse" version = "0.6.2" @@ -5174,6 +5741,12 @@ dependencies = [ "tap", ] +[[package]] +name = "xmlparser" +version = "0.13.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" + [[package]] name = "yoke" version = "0.8.0" diff --git a/README.md b/README.md index 5cc391751..1098aa95e 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,25 @@ password = "hunter2" If a database in `pgdog.toml` doesn't have a user in `users.toml`, the connection pool for that database will not be created and users won't be able to connect. +### RDS IAM backend authentication + +PgDog can keep client-to-proxy authentication unchanged while using AWS RDS IAM tokens for proxy-to-PostgreSQL authentication on a per-user basis. + +```toml +[[users]] +name = "alice" +database = "pgdog" +password = "client-password" +server_auth = "rds_iam" +# Optional; PgDog infers region from *.region.rds.amazonaws.com(.cn) hostnames when omitted. +# server_iam_region = "us-east-1" +``` + +When any user has `server_auth = "rds_iam"`: + +- `general.tls_verify` must not be `"disabled"`. +- `general.passthrough_auth` must be `"disabled"`. + If you'd like to try it out locally, create the database and user like so: ```sql diff --git a/example.pgdog.toml b/example.pgdog.toml index 9e44070cf..631587d75 100644 --- a/example.pgdog.toml +++ b/example.pgdog.toml @@ -120,6 +120,8 @@ tls_client_required = false # - prefer # - verify-ca # - verify-full +# NOTE: if any user sets `server_auth = "rds_iam"` in users.toml, +# this cannot be "disabled". tls_verify = "disabled" # Path to PEM-encoded certificate bundle to use for Postgres server @@ -197,6 +199,8 @@ query_cache_limit = 1_000 # - enabled (requires TLS) # - enabled_plain # +# NOTE: `passthrough_auth` cannot be enabled when using +# per-user backend `server_auth = "rds_iam"`. passthrough_auth = "disabled" # How long to wait for Postgres server connections to be created by the pool before diff --git a/example.users.toml b/example.users.toml index 1603b71ba..9aa05dce1 100644 --- a/example.users.toml +++ b/example.users.toml @@ -10,3 +10,9 @@ password = "pgdog" name = "pgdog" database = "pgdog_sharded" password = "pgdog" + +# Example: backend authentication with AWS RDS IAM token generation. +# PgDog still authenticates the client as configured by `general.auth_type`; +# this only affects how PgDog authenticates to PostgreSQL servers. +# server_auth = "rds_iam" +# server_iam_region = "us-east-1" # optional; auto-inferred from RDS hostname when omitted diff --git a/pgdog-config/src/core.rs b/pgdog-config/src/core.rs index 64199dbb5..eadef9cd8 100644 --- a/pgdog-config/src/core.rs +++ b/pgdog-config/src/core.rs @@ -15,12 +15,12 @@ use crate::{ use super::database::Database; use super::error::Error; use super::general::General; -use super::networking::{MultiTenant, Tcp}; +use super::networking::{MultiTenant, Tcp, TlsVerifyMode}; use super::pooling::PoolerMode; use super::replication::{MirrorConfig, Mirroring, ReplicaLag, Replication}; use super::rewrite::Rewrite; use super::sharding::{ManualQuery, OmnishardedTables, ShardedMapping, ShardedTable}; -use super::users::{Admin, Plugin, Users}; +use super::users::{Admin, Plugin, ServerAuth, Users}; #[derive(Debug, Clone, PartialEq)] pub struct ConfigAndUsers { @@ -57,8 +57,7 @@ impl ConfigAndUsers { } let mut users: Users = if let Ok(users) = read_to_string(users_path) { - let mut users: Users = toml::from_str(&users)?; - users.check(&config); + let users: Users = toml::from_str(&users)?; info!("loaded \"{}\"", users_path.display()); users } else { @@ -82,12 +81,49 @@ impl ConfigAndUsers { warn!("admin password has been randomly generated"); } - Ok(ConfigAndUsers { + let mut config_and_users = ConfigAndUsers { config, users, config_path: config_path.to_owned(), users_path: users_path.to_owned(), - }) + }; + + config_and_users.check()?; + + Ok(config_and_users) + } + + pub fn check(&mut self) -> Result<(), Error> { + self.config.check(); + self.users.check(&self.config); + self.validate_server_auth()?; + Ok(()) + } + + fn validate_server_auth(&self) -> Result<(), Error> { + let has_rds_iam_user = self + .users + .users + .iter() + .any(|user| user.server_auth == ServerAuth::RdsIam); + + if !has_rds_iam_user { + return Ok(()); + } + + if self.config.general.passthrough_auth != PassthoughAuth::Disabled { + return Err(Error::ParseError( + "\"passthrough_auth\" must be \"disabled\" when any user has \"server_auth = \\\"rds_iam\\\"\"".into(), + )); + } + + if self.config.general.tls_verify == TlsVerifyMode::Disabled { + return Err(Error::ParseError( + "\"tls_verify\" cannot be \"disabled\" when any user has \"server_auth = \\\"rds_iam\\\"\"".into(), + )); + } + + Ok(()) } /// Prepared statements are enabled. @@ -1179,4 +1215,40 @@ shard = 0 assert_eq!(dest.host, "source-host"); assert_eq!(dest.port, 5432); } + + #[test] + fn test_rds_iam_rejects_passthrough_auth() { + let mut config = ConfigAndUsers::default(); + config.config.general.passthrough_auth = PassthoughAuth::EnabledPlain; + config.config.general.tls_verify = TlsVerifyMode::VerifyFull; + config.users.users.push(crate::User { + name: "alice".into(), + database: "db".into(), + password: Some("secret".into()), + server_auth: ServerAuth::RdsIam, + ..Default::default() + }); + + let err = config.check().unwrap_err().to_string(); + assert!(err.contains("passthrough_auth")); + assert!(err.contains("rds_iam")); + } + + #[test] + fn test_rds_iam_rejects_tls_verify_disabled() { + let mut config = ConfigAndUsers::default(); + config.config.general.tls_verify = TlsVerifyMode::Disabled; + config.config.general.passthrough_auth = PassthoughAuth::Disabled; + config.users.users.push(crate::User { + name: "alice".into(), + database: "db".into(), + password: Some("secret".into()), + server_auth: ServerAuth::RdsIam, + ..Default::default() + }); + + let err = config.check().unwrap_err().to_string(); + assert!(err.contains("tls_verify")); + assert!(err.contains("rds_iam")); + } } diff --git a/pgdog-config/src/lib.rs b/pgdog-config/src/lib.rs index a08b302c5..3eec9c54a 100644 --- a/pgdog-config/src/lib.rs +++ b/pgdog-config/src/lib.rs @@ -33,7 +33,7 @@ pub use replication::*; pub use rewrite::{Rewrite, RewriteMode}; pub use sharding::*; pub use system_catalogs::system_catalogs; -pub use users::{Admin, Plugin, User, Users}; +pub use users::{Admin, Plugin, ServerAuth, User, Users}; use std::time::Duration; diff --git a/pgdog-config/src/users.rs b/pgdog-config/src/users.rs index ab36f5af0..f5bffd4dd 100644 --- a/pgdog-config/src/users.rs +++ b/pgdog-config/src/users.rs @@ -77,6 +77,25 @@ impl Users { } } +/// Backend authentication mode used by PgDog for server connections. +#[derive( + Serialize, Deserialize, Debug, Clone, Copy, Default, PartialEq, Eq, Ord, PartialOrd, Hash, +)] +#[serde(rename_all = "snake_case")] +pub enum ServerAuth { + /// Use configured static password. + #[default] + Password, + /// Generate an AWS RDS IAM auth token per connection attempt. + RdsIam, +} + +impl ServerAuth { + pub fn rds_iam(&self) -> bool { + matches!(self, Self::RdsIam) + } +} + /// User allowed to connect to pgDog. #[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq, Eq, Ord, PartialOrd)] #[serde(deny_unknown_fields)] @@ -104,6 +123,11 @@ pub struct User { pub server_user: Option, /// Server password. pub server_password: Option, + /// Backend auth mode for server connections. + #[serde(default)] + pub server_auth: ServerAuth, + /// Optional region override for RDS IAM token generation. + pub server_iam_region: Option, /// Statement timeout. pub statement_timeout: Option, /// Relication mode. @@ -253,4 +277,36 @@ mod tests { .unwrap(); assert_eq!(bob_source.password(), "pass4"); } + + #[test] + fn test_user_server_auth_defaults_to_password() { + let source = r#" +[[users]] +name = "alice" +database = "db" +password = "secret" +"#; + + let users: Users = toml::from_str(source).unwrap(); + let user = users.users.first().unwrap(); + assert_eq!(user.server_auth, ServerAuth::Password); + assert!(user.server_iam_region.is_none()); + } + + #[test] + fn test_user_server_auth_rds_iam_with_region() { + let source = r#" +[[users]] +name = "alice" +database = "db" +password = "secret" +server_auth = "rds_iam" +server_iam_region = "us-east-1" +"#; + + let users: Users = toml::from_str(source).unwrap(); + let user = users.users.first().unwrap(); + assert_eq!(user.server_auth, ServerAuth::RdsIam); + assert_eq!(user.server_iam_region.as_deref(), Some("us-east-1")); + } } diff --git a/pgdog/Cargo.toml b/pgdog/Cargo.toml index 18f820478..1bedf5fc0 100644 --- a/pgdog/Cargo.toml +++ b/pgdog/Cargo.toml @@ -63,6 +63,8 @@ hickory-resolver = "0.25.2" lazy_static = "1" dashmap = "6" derive_builder = "0.20.2" +aws-config = { version = "1", features = ["behavior-version-latest"] } +aws-sdk-rds = "1" pgdog-config = { path = "../pgdog-config" } pgdog-vector = { path = "../pgdog-vector" } pgdog-stats = { path = "../pgdog-stats" } diff --git a/pgdog/src/backend/auth/mod.rs b/pgdog/src/backend/auth/mod.rs new file mode 100644 index 000000000..9fb005282 --- /dev/null +++ b/pgdog/src/backend/auth/mod.rs @@ -0,0 +1 @@ +pub mod rds_iam; diff --git a/pgdog/src/backend/auth/rds_iam.rs b/pgdog/src/backend/auth/rds_iam.rs new file mode 100644 index 000000000..cc29ff20f --- /dev/null +++ b/pgdog/src/backend/auth/rds_iam.rs @@ -0,0 +1,170 @@ +use aws_config::{BehaviorVersion, Region}; +use aws_sdk_rds::auth_token::{AuthTokenGenerator, Config as AuthTokenConfig}; + +use crate::backend::{pool::Address, Error}; + +fn infer_region_from_rds_host(host: &str) -> Option { + let host = host.to_ascii_lowercase(); + let labels = host.split('.').collect::>(); + let rds = labels.iter().position(|label| *label == "rds")?; + + if rds == 0 { + return None; + } + + match labels.get(rds + 1..) { + // *.region.rds.amazonaws.com + Some(["amazonaws", "com"]) => {} + // *.region.rds.amazonaws.com.cn + Some(["amazonaws", "com", "cn"]) => {} + _ => return None, + } + + let region = labels.get(rds - 1)?; + if region.is_empty() { + return None; + } + + Some((*region).to_owned()) +} + +fn resolve_region(addr: &Address) -> Result { + if let Some(region) = addr.server_iam_region.as_ref() { + if !region.is_empty() { + return Ok(region.clone()); + } + } + + infer_region_from_rds_host(&addr.host).ok_or_else(|| { + Error::RdsIamToken(format!( + "unable to infer AWS region from host \"{}\"; set \"server_iam_region\"", + addr.host + )) + }) +} + +pub async fn token(addr: &Address) -> Result { + #[cfg(test)] + if let Some(token) = test_token_override() { + return Ok(token); + } + + let region = resolve_region(addr)?; + let sdk_config = aws_config::load_defaults(BehaviorVersion::latest()).await; + + let config = AuthTokenConfig::builder() + .hostname(addr.host.as_str()) + .port(addr.port.into()) + .username(addr.user.as_str()) + .region(Region::new(region.clone())) + .build() + .map_err(|error| { + Error::RdsIamToken(format!( + "failed to build RDS IAM token config for {}@{}:{} in region {}: {}", + addr.user, addr.host, addr.port, region, error + )) + })?; + + AuthTokenGenerator::new(config) + .auth_token(&sdk_config) + .await + .map(|token| token.to_string()) + .map_err(|error| { + Error::RdsIamToken(format!( + "failed to generate RDS IAM token for {}@{}:{} in region {}: {}", + addr.user, addr.host, addr.port, region, error + )) + }) +} + +#[cfg(test)] +fn test_token_override() -> Option { + TEST_TOKEN_OVERRIDE.lock().clone() +} + +#[cfg(test)] +pub(crate) fn set_test_token_override(token: Option) { + *TEST_TOKEN_OVERRIDE.lock() = token; +} + +#[cfg(test)] +static TEST_TOKEN_OVERRIDE: once_cell::sync::Lazy>> = + once_cell::sync::Lazy::new(|| parking_lot::Mutex::new(None)); + +#[cfg(test)] +mod tests { + use std::env; + + use crate::backend::pool::Address; + use crate::config::ServerAuth; + + use super::*; + + struct EnvVarGuard { + key: &'static str, + previous: Option, + } + + impl EnvVarGuard { + fn set(key: &'static str, value: &str) -> Self { + let previous = env::var(key).ok(); + env::set_var(key, value); + Self { key, previous } + } + } + + impl Drop for EnvVarGuard { + fn drop(&mut self) { + if let Some(previous) = self.previous.take() { + env::set_var(self.key, previous); + } else { + env::remove_var(self.key); + } + } + } + + #[test] + fn test_infer_region_commercial_endpoint() { + let region = infer_region_from_rds_host("db.cluster-abc123.us-east-1.rds.amazonaws.com"); + assert_eq!(region.as_deref(), Some("us-east-1")); + } + + #[test] + fn test_infer_region_china_endpoint() { + let region = + infer_region_from_rds_host("db.cluster-abc123.cn-north-1.rds.amazonaws.com.cn"); + assert_eq!(region.as_deref(), Some("cn-north-1")); + } + + #[test] + fn test_infer_region_fails_for_custom_hostname() { + let region = infer_region_from_rds_host("postgres.internal.example.com"); + assert!(region.is_none()); + } + + #[tokio::test] + async fn test_token_contains_expected_query_fields() { + let _access_key = EnvVarGuard::set("AWS_ACCESS_KEY_ID", "AKIDEXAMPLE"); + let _secret_key = EnvVarGuard::set("AWS_SECRET_ACCESS_KEY", "SECRETEXAMPLE"); + let _session = EnvVarGuard::set("AWS_SESSION_TOKEN", "SESSIONEXAMPLE"); + + let addr = Address { + host: "db.cluster-abc123.us-east-1.rds.amazonaws.com".into(), + port: 5432, + database_name: "postgres".into(), + user: "db_user".into(), + password: String::new(), + database_number: 0, + server_auth: ServerAuth::RdsIam, + server_iam_region: Some("us-east-1".into()), + }; + + let token = token(&addr).await.unwrap(); + assert!(token.starts_with( + "db.cluster-abc123.us-east-1.rds.amazonaws.com:5432/?Action=connect&DBUser=db_user" + )); + assert!(token.contains("X-Amz-Algorithm=AWS4-HMAC-SHA256")); + assert!(token.contains("X-Amz-Credential=")); + assert!(token.contains("X-Amz-Signature=")); + } +} diff --git a/pgdog/src/backend/error.rs b/pgdog/src/backend/error.rs index dee0c2c91..f1dc18c74 100644 --- a/pgdog/src/backend/error.rs +++ b/pgdog/src/backend/error.rs @@ -117,6 +117,9 @@ pub enum Error { #[error("could not resolve to any address for hostname {0}")] DnsResolutionFailed(String), + #[error("RDS IAM token generation failed: {0}")] + RdsIamToken(String), + #[error("pub/sub channel disabled")] PubSubDisabled, diff --git a/pgdog/src/backend/mod.rs b/pgdog/src/backend/mod.rs index 2bc400685..0ced5523c 100644 --- a/pgdog/src/backend/mod.rs +++ b/pgdog/src/backend/mod.rs @@ -1,5 +1,6 @@ //! pgDog backend managers connections to PostgreSQL. +pub mod auth; pub mod connect_reason; pub mod databases; pub mod disconnect_reason; diff --git a/pgdog/src/backend/pool/address.rs b/pgdog/src/backend/pool/address.rs index 63b141163..41ab5ee1c 100644 --- a/pgdog/src/backend/pool/address.rs +++ b/pgdog/src/backend/pool/address.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use url::Url; use crate::backend::{pool::dns_cache::DnsCache, Error}; -use crate::config::{config, Database, User}; +use crate::config::{config, Database, ServerAuth, User}; /// Server address. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default, Eq, Hash)] @@ -20,6 +20,11 @@ pub struct Address { pub user: String, /// Password. pub password: String, + /// Server auth mode for backend connections. + #[serde(default)] + pub server_auth: ServerAuth, + /// Optional IAM region override. + pub server_iam_region: Option, /// Database number (in the config). pub database_number: usize, } @@ -27,6 +32,8 @@ pub struct Address { impl Address { /// Create new address from config values. pub fn new(database: &Database, user: &User, database_number: usize) -> Self { + let server_auth = user.server_auth; + Address { host: database.host.clone(), port: database.port, @@ -42,17 +49,30 @@ impl Address { } else { user.name.clone() }, - password: if let Some(password) = database.password.clone() { - password - } else if let Some(password) = user.server_password.clone() { - password + password: if server_auth.rds_iam() { + String::new() } else { - user.password().to_string() + if let Some(password) = database.password.clone() { + password + } else if let Some(password) = user.server_password.clone() { + password + } else { + user.password().to_string() + } }, + server_auth, + server_iam_region: user.server_iam_region.clone(), database_number, } } + pub async fn auth_secret(&self) -> Result { + match self.server_auth { + ServerAuth::Password => Ok(self.password.clone()), + ServerAuth::RdsIam => crate::backend::auth::rds_iam::token(self).await, + } + } + pub async fn addr(&self) -> Result { let dns_cache_override_enabled = config().config.general.dns_ttl().is_some(); @@ -77,6 +97,8 @@ impl Address { user: "pgdog".into(), password: "pgdog".into(), database_name: "pgdog".into(), + server_auth: ServerAuth::Password, + server_iam_region: None, database_number: 0, } } @@ -108,6 +130,8 @@ impl TryFrom for Address { password, user, database_name, + server_auth: ServerAuth::Password, + server_iam_region: None, database_number: 0, }) } @@ -152,6 +176,32 @@ mod test { assert_eq!(address.password, "hunter3"); } + #[test] + fn test_rds_iam_does_not_use_static_password() { + let database = Database { + name: "pgdog".into(), + host: "127.0.0.1".into(), + port: 6432, + password: Some("db-level-pass".into()), + ..Default::default() + }; + + let user = User { + name: "pgdog".into(), + password: Some("user-pass".into()), + server_password: Some("server-pass".into()), + server_auth: ServerAuth::RdsIam, + server_iam_region: Some("us-east-1".into()), + database: "pgdog".into(), + ..Default::default() + }; + + let address = Address::new(&database, &user, 0); + assert_eq!(address.password, ""); + assert_eq!(address.server_auth, ServerAuth::RdsIam); + assert_eq!(address.server_iam_region.as_deref(), Some("us-east-1")); + } + #[test] fn test_addr_from_url() { let addr = @@ -162,5 +212,27 @@ mod test { assert_eq!(addr.database_name, "pgdb"); assert_eq!(addr.user, "user"); assert_eq!(addr.password, "password"); + assert_eq!(addr.server_auth, ServerAuth::Password); + assert!(addr.server_iam_region.is_none()); + } + + #[tokio::test] + async fn test_auth_secret_password_mode() { + let addr = Address::new_test(); + assert_eq!(addr.auth_secret().await.unwrap(), "pgdog"); + } + + #[tokio::test] + async fn test_auth_secret_rds_iam_mode_uses_generator() { + let mut addr = Address::new_test(); + addr.server_auth = ServerAuth::RdsIam; + addr.server_iam_region = Some("us-east-1".into()); + addr.password = "wrong".into(); + + crate::backend::auth::rds_iam::set_test_token_override(Some("token-from-iam".into())); + let secret = addr.auth_secret().await.unwrap(); + crate::backend::auth::rds_iam::set_test_token_override(None); + + assert_eq!(secret, "token-from-iam"); } } diff --git a/pgdog/src/backend/schema/sync/pg_dump.rs b/pgdog/src/backend/schema/sync/pg_dump.rs index 446307e09..a73a226a3 100644 --- a/pgdog/src/backend/schema/sync/pg_dump.rs +++ b/pgdog/src/backend/schema/sync/pg_dump.rs @@ -23,7 +23,7 @@ use crate::{ replication::{publisher::PublicationTable, status::SchemaStatement}, Cluster, }, - config::config, + config::{config, ServerAuth}, frontend::router::parser::{sequence::Sequence, Column, Table}, }; @@ -118,6 +118,31 @@ pub struct PgDump { publication: String, } +fn build_pg_dump_command( + pg_dump_path: &str, + addr: &backend::pool::Address, + auth_secret: &str, +) -> Command { + let mut command = Command::new(pg_dump_path); + command + .arg("--schema-only") + .arg("-h") + .arg(&addr.host) + .arg("-p") + .arg(addr.port.to_string()) + .arg("-U") + .arg(&addr.user) + .env("PGPASSWORD", auth_secret) + .arg("-d") + .arg(&addr.database_name); + + if addr.server_auth == ServerAuth::RdsIam { + command.env("PGSSLMODE", "require"); + } + + command +} + impl PgDump { pub fn new(source: &Cluster, publication: &str) -> Self { Self { @@ -185,19 +210,9 @@ impl PgDump { .to_str() .unwrap_or("pg_dump"); - let output = Command::new(pg_dump_path) - .arg("--schema-only") - .arg("-h") - .arg(&addr.host) - .arg("-p") - .arg(addr.port.to_string()) - .arg("-U") - .arg(&addr.user) - .env("PGPASSWORD", &addr.password) - .arg("-d") - .arg(&addr.database_name) - .output() - .await?; + let auth_secret = addr.auth_secret().await?; + let mut command = build_pg_dump_command(pg_dump_path, &addr, &auth_secret); + let output = command.output().await?; if !output.status.success() { let err = from_utf8(&output.stderr)?; @@ -1146,6 +1161,10 @@ impl PgDumpOutput { #[cfg(test)] mod test { + use std::ffi::OsStr; + + use crate::config::ServerAuth; + use super::*; #[tokio::test] @@ -1154,6 +1173,43 @@ mod test { let _pg_dump = PgDump::new(&cluster, "test_pg_dump_execute"); } + #[test] + fn test_build_pg_dump_command_sets_password_env() { + let addr = backend::pool::Address::new_test(); + let command = build_pg_dump_command("pg_dump", &addr, "secret"); + + let env = command + .as_std() + .get_envs() + .find(|(key, _)| *key == OsStr::new("PGPASSWORD")) + .and_then(|(_, value)| value); + + assert_eq!(env, Some(OsStr::new("secret"))); + + let sslmode = command + .as_std() + .get_envs() + .find(|(key, _)| *key == OsStr::new("PGSSLMODE")) + .and_then(|(_, value)| value); + + assert_eq!(sslmode, None); + } + + #[test] + fn test_build_pg_dump_command_sets_tls_for_rds_iam() { + let mut addr = backend::pool::Address::new_test(); + addr.server_auth = ServerAuth::RdsIam; + let command = build_pg_dump_command("pg_dump", &addr, "token"); + + let sslmode = command + .as_std() + .get_envs() + .find(|(key, _)| *key == OsStr::new("PGSSLMODE")) + .and_then(|(_, value)| value); + + assert_eq!(sslmode, Some(OsStr::new("require"))); + } + #[test] fn test_specific_dump() { let dump = r#" diff --git a/pgdog/src/backend/server.rs b/pgdog/src/backend/server.rs index 666ae4f24..8addc1864 100644 --- a/pgdog/src/backend/server.rs +++ b/pgdog/src/backend/server.rs @@ -164,7 +164,8 @@ impl Server { stream.flush().await?; // Perform authentication. - let mut scram = Client::new(&addr.user, &addr.password); + let auth_secret = addr.auth_secret().await?; + let mut scram = Client::new(&addr.user, &auth_secret); let mut auth_type = AuthType::Trust; loop { let message = stream.read().await?; @@ -180,7 +181,7 @@ impl Server { match auth { Authentication::Ok => break, Authentication::ClearTextPassword => { - let password = Password::new_password(&addr.password); + let password = Password::new_password(&auth_secret); stream.send_flush(&password).await?; } Authentication::Sasl(_) => { @@ -200,7 +201,7 @@ impl Server { } Authentication::Md5(salt) => { auth_type = AuthType::Md5; - let client = md5::Client::new_salt(&addr.user, &addr.password, &salt)?; + let client = md5::Client::new_salt(&addr.user, &auth_secret, &salt)?; stream.send_flush(&client.response()).await?; } } @@ -1024,10 +1025,30 @@ impl Drop for Server { // Used for testing. #[cfg(test)] pub mod test { + use bytes::{BufMut, BytesMut}; + use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpListener, + }; + use crate::{config::Memory, frontend::PreparedStatements, net::*}; use super::{Error, *}; + async fn read_password_message(stream: &mut tokio::net::TcpStream) -> Password { + let code = stream.read_u8().await.unwrap(); + let len = stream.read_i32().await.unwrap(); + let mut payload = vec![0; (len - 4) as usize]; + stream.read_exact(&mut payload).await.unwrap(); + + let mut bytes = BytesMut::with_capacity(len as usize + 1); + bytes.put_u8(code); + bytes.put_i32(len); + bytes.extend_from_slice(&payload); + + Password::from_bytes(bytes.freeze()).unwrap() + } + impl Default for Server { fn default() -> Self { let id = BackendKeyData::new(); @@ -1092,6 +1113,63 @@ pub mod test { .unwrap() } + #[tokio::test] + async fn test_connect_rds_iam_uses_dynamic_token_not_static_password() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let expected_secret = "iam-token-for-test".to_string(); + let server_task = tokio::spawn({ + let expected_secret = expected_secret.clone(); + async move { + let (mut socket, _) = listener.accept().await.unwrap(); + + let startup = Startup::from_stream(&mut socket).await.unwrap(); + let startup = if matches!(startup, Startup::Ssl) { + socket.write_all(b"N").await.unwrap(); + Startup::from_stream(&mut socket).await.unwrap() + } else { + startup + }; + assert!(matches!(startup, Startup::Startup { .. })); + + socket + .write_all(&Authentication::ClearTextPassword.to_bytes().unwrap()) + .await + .unwrap(); + + let password = read_password_message(&mut socket).await; + assert_eq!(password.password(), Some(expected_secret.as_str())); + + socket + .write_all(&Authentication::Ok.to_bytes().unwrap()) + .await + .unwrap(); + socket + .write_all(&BackendKeyData::new().to_bytes().unwrap()) + .await + .unwrap(); + socket + .write_all(&ReadyForQuery::idle().to_bytes().unwrap()) + .await + .unwrap(); + } + }); + + let mut addr = Address::new_test(); + addr.port = port; + addr.server_auth = crate::config::ServerAuth::RdsIam; + addr.server_iam_region = Some("us-east-1".into()); + addr.password = "wrong-password".into(); + + crate::backend::auth::rds_iam::set_test_token_override(Some(expected_secret)); + let result = Server::connect(&addr, ServerOptions::default(), ConnectReason::Other).await; + crate::backend::auth::rds_iam::set_test_token_override(None); + + let server = result.unwrap(); + drop(server); + server_task.await.unwrap(); + } + #[tokio::test] async fn test_simple_query() { let mut server = test_server().await; diff --git a/pgdog/src/config/mod.rs b/pgdog/src/config/mod.rs index c20aa4d37..57e0cb107 100644 --- a/pgdog/src/config/mod.rs +++ b/pgdog/src/config/mod.rs @@ -26,7 +26,7 @@ pub use pgdog_config::auth::{AuthType, PassthoughAuth}; pub use pgdog_config::{LoadBalancingStrategy, ReadWriteSplit, ReadWriteStrategy}; pub use pooling::{ConnectionRecovery, PoolerMode, PreparedStatements}; pub use rewrite::{Rewrite, RewriteMode}; -pub use users::{Admin, Plugin, User, Users}; +pub use users::{Admin, Plugin, ServerAuth, User, Users}; // Re-export from sharding module pub use sharding::{ @@ -61,7 +61,7 @@ pub fn load(config: &PathBuf, users: &PathBuf) -> Result } pub fn set(mut config: ConfigAndUsers) -> Result { - config.config.check(); + config.check()?; for table in config.config.sharded_tables.iter_mut() { table.load_centroids()?; }