diff --git a/Cargo.lock b/Cargo.lock index b936247..88dda8a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -331,6 +331,7 @@ version = "0.0.1" dependencies = [ "anyhow", "attestation", + "mock-tdx", "nested-tls", "ra-tls", "rcgen 0.14.7", diff --git a/crates/attestation/src/dcap.rs b/crates/attestation/src/dcap.rs index 62310dc..b7a3922 100644 --- a/crates/attestation/src/dcap.rs +++ b/crates/attestation/src/dcap.rs @@ -236,10 +236,16 @@ pub async fn verify_dcap_attestation( pub fn verify_dcap_attestation_sync( input: Vec, expected_input_data: [u8; 64], - _pccs: Pccs, + pccs: Pccs, ) -> Result { - // In tests we use mock quotes which will fail to verify let quote = Quote::parse(&input)?; + let ca = quote.ca()?; + let fmspc = hex::encode_upper(quote.fmspc()?); + let now = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs(); + let collateral = pccs.get_collateral_sync(fmspc, ca, now)?; + let verifier = mock_tdx::mock_dcap_verifier(); + verifier.verify(&input, &collateral, now)?; + let measurements = MultiMeasurements::from_dcap_qvl_quote("e)?; if get_quote_input_data(quote.report.clone()) != expected_input_data { return Err(DcapVerificationError::InputMismatch); @@ -396,7 +402,12 @@ mod tests { #[tokio::test] async fn test_mock_dcap_verify_uses_pccs_when_provided() { - let mock_pcs = spawn_mock_pcs_server(MockPcsConfig::default()).await.unwrap(); + let mock_pcs = spawn_mock_pcs_server(MockPcsConfig { + include_fmspcs_listing: false, + ..MockPcsConfig::default() + }) + .await + .unwrap(); let pccs = Pccs::new(Some(mock_pcs.base_url.clone())); let expected_input_data = [0xA5; 64]; let attestation_bytes = create_dcap_attestation(expected_input_data).unwrap(); diff --git a/crates/attestation/src/lib.rs b/crates/attestation/src/lib.rs index 5f8c712..0b7b2a5 100644 --- a/crates/attestation/src/lib.rs +++ b/crates/attestation/src/lib.rs @@ -326,6 +326,18 @@ impl AttestationVerifier { } } + /// Expect mock measurements used in tests, and use a PCCS + #[cfg(any(test, feature = "mock"))] + pub fn mock_with_pccs(pccs_url: String) -> Self { + Self { + measurement_policy: MeasurementPolicy::mock(), + pccs_url: None, + dump_dcap_quotes: false, + override_azure_outdated_tcb: false, + internal_pccs: Some(Pccs::new(Some(pccs_url))), + } + } + /// Resolves once the internal PCCS cache is ready to verify /// attestations /// @@ -601,6 +613,7 @@ pub enum AttestationError { #[cfg(test)] mod tests { + use mock_tdx::mock_pcs::{MockPcsConfig, spawn_mock_pcs_server}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::TcpListener, @@ -675,11 +688,17 @@ mod tests { assert_eq!(wrapped.attestation, vec![9, 8]); } - #[test] - fn mock_verifier_supports_sync_verification() { + #[tokio::test] + async fn mock_verifier_supports_sync_verification() { let input_data = [7u8; 64]; let attestation = dcap::create_dcap_attestation(input_data).unwrap(); - let verifier = AttestationVerifier::mock(); + + let mock_pcs_server = spawn_mock_pcs_server(MockPcsConfig::default()).await.unwrap(); + + let verifier = AttestationVerifier::mock_with_pccs(mock_pcs_server.base_url.clone()); + if let Some(ref pccs) = verifier.internal_pccs { + pccs.ready().await.unwrap(); + } let result = verifier.verify_attestation_sync( AttestationExchangeMessage { attestation_type: AttestationType::DcapTdx, attestation }, diff --git a/crates/attested-tls/Cargo.toml b/crates/attested-tls/Cargo.toml index 912f495..ce1568b 100644 --- a/crates/attested-tls/Cargo.toml +++ b/crates/attested-tls/Cargo.toml @@ -21,9 +21,9 @@ yasna = "0.5.2" [dev-dependencies] attestation = { workspace = true, features = ["mock"] } -rustls = { workspace = true, default-features = false, features = ["aws_lc_rs"] } - +mock-tdx = { workspace = true } nested-tls = { path = "../nested-tls" } +rustls = { workspace = true, default-features = false, features = ["aws_lc_rs"] } [lints] workspace = true diff --git a/crates/attested-tls/src/lib.rs b/crates/attested-tls/src/lib.rs index 479487d..ab26cf2 100644 --- a/crates/attested-tls/src/lib.rs +++ b/crates/attested-tls/src/lib.rs @@ -1033,6 +1033,7 @@ pub enum AttestedTlsError { mod tests { use std::{io::Cursor, sync::Arc}; + use mock_tdx::mock_pcs::{MockPcsConfig, spawn_mock_pcs_server}; use ra_tls::rcgen::{ BasicConstraints, CertificateParams, @@ -1084,6 +1085,24 @@ mod tests { ) } + async fn ready_mock_attested_verifier( + root_store: Option, + provider: Arc, + ) -> AttestedCertificateVerifier { + let mock_pcs_server = spawn_mock_pcs_server(MockPcsConfig::default()).await.unwrap(); + let verifier = AttestationVerifier::mock_with_pccs(mock_pcs_server.base_url.clone()); + if let Some(ref pccs) = verifier.internal_pccs { + pccs.ready().await.unwrap(); + } + + let mut builder = + AttestedCertificateVerifier::build(verifier).with_crypto_provider(provider); + if let Some(root_store) = root_store { + builder = builder.with_root_cert_store(root_store); + } + builder.finish().unwrap() + } + #[tokio::test] async fn certificate_resolver_creates_initial_certificate() { let provider: Arc = aws_lc_rs::default_provider().into(); @@ -1135,10 +1154,7 @@ mod tests { .finish() .unwrap(); - let verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock()) - .with_crypto_provider(provider.clone()) - .finish() - .unwrap(); + let verifier = ready_mock_attested_verifier(None, provider.clone()).await; let server_config = ServerConfig::builder_with_provider(provider.clone()) .with_safe_default_protocol_versions() @@ -1192,11 +1208,7 @@ mod tests { let mut roots = RootCertStore::empty(); roots.add(ca_cert).unwrap(); - let verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock()) - .with_crypto_provider(provider.clone()) - .with_root_cert_store(roots) - .finish() - .unwrap(); + let verifier = ready_mock_attested_verifier(Some(roots), provider.clone()).await; let server_config = ServerConfig::builder_with_provider(provider.clone()) .with_safe_default_protocol_versions() @@ -1275,14 +1287,8 @@ mod tests { .finish() .unwrap(); - let server_verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock()) - .with_crypto_provider(provider.clone()) - .finish() - .unwrap(); - let client_verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock()) - .with_crypto_provider(provider.clone()) - .finish() - .unwrap(); + let server_verifier = ready_mock_attested_verifier(None, provider.clone()).await; + let client_verifier = ready_mock_attested_verifier(None, provider.clone()).await; let server_config = ServerConfig::builder_with_provider(provider.clone()) .with_safe_default_protocol_versions() @@ -1327,10 +1333,7 @@ mod tests { .with_certificate_validity(Duration::from_secs(4)) .finish() .unwrap(); - let verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock()) - .with_crypto_provider(provider.clone()) - .finish() - .unwrap(); + let verifier = ready_mock_attested_verifier(None, provider.clone()).await; let server_config = ServerConfig::builder_with_provider(provider.clone()) .with_safe_default_protocol_versions() @@ -1539,7 +1542,12 @@ mod tests { .with_certificate_validity(Duration::from_secs(4)) .finish() .unwrap(); - let verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock()) + let mock_pcs_server = spawn_mock_pcs_server(MockPcsConfig::default()).await.unwrap(); + let verifier = AttestationVerifier::mock_with_pccs(mock_pcs_server.base_url.clone()); + if let Some(ref pccs) = verifier.internal_pccs { + pccs.ready().await.unwrap(); + } + let verifier = AttestedCertificateVerifier::build(verifier) .with_crypto_provider(provider) .with_allowed_leaf_cert_pubkey(&key_pair.public_key_der()) .finish() @@ -1674,10 +1682,7 @@ mod tests { .with_certificate_validity(Duration::from_secs(4)) .finish() .unwrap(); - let mut verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock()) - .with_crypto_provider(provider) - .finish() - .unwrap(); + let mut verifier = ready_mock_attested_verifier(None, provider).await; let cert = resolver.state.certificate.read().unwrap().first().unwrap().clone(); let (expected_input_data, not_after) = AttestedCertificateVerifier::cert_binding_data( @@ -1708,6 +1713,78 @@ mod tests { .unwrap(); } + #[tokio::test(flavor = "multi_thread")] + async fn sync_verifier_cache_miss_fails_then_succeeds_after_background_fetch() { + let provider: Arc = aws_lc_rs::default_provider().into(); + let key_pair = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); + let resolver = AttestedCertificateResolver::build( + "foo", + AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(), + ) + .with_crypto_provider(provider.clone()) + .with_key_pair(&key_pair) + .with_certificate_validity(Duration::from_secs(4)) + .finish() + .unwrap(); + let cert = resolver.state.certificate.read().unwrap().first().unwrap().clone(); + + // Mock PCS is set up to not list the FMSPCs, meaning the pre-warm + // wont fetch anything + let mock_pcs = spawn_mock_pcs_server(MockPcsConfig { + include_fmspcs_listing: false, + ..MockPcsConfig::default() + }) + .await + .unwrap(); + + let verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock_with_pccs( + mock_pcs.base_url.clone(), + )) + .with_crypto_provider(provider) + .finish() + .unwrap(); + + let first_result = verify_server_cert_direct( + &verifier, + &cert, + &ServerName::try_from("foo").unwrap(), + UnixTime::now(), + ); + + // Initially verification fails because the PCCS doesn't have the + // collateral associated with the quote + assert_eq!( + first_result.unwrap_err(), + Error::InvalidCertificate(CertificateError::ApplicationVerificationFailure) + ); + + // Now we wait a moment for the PCCS to fetch it in the background + for _ in 0..50 { + if verify_server_cert_direct( + &verifier, + &cert, + &ServerName::try_from("foo").unwrap(), + UnixTime::now(), + ) + .is_ok() + { + break; + } + tokio::time::sleep(Duration::from_millis(20)).await; + } + + // Now verification succeeds + verify_server_cert_direct( + &verifier, + &cert, + &ServerName::try_from("foo").unwrap(), + UnixTime::now(), + ) + .unwrap(); + assert_eq!(mock_pcs.tcb_call_count(), 1); + assert_eq!(mock_pcs.qe_call_count(), 1); + } + /// Helper to create a private certificate authority fn test_ca() -> CaCert { let key = KeyPair::generate_for(&PKCS_ECDSA_P256_SHA256).unwrap(); diff --git a/crates/attested-tls/tests/nested_tls.rs b/crates/attested-tls/tests/nested_tls.rs index 36c2307..7110e8f 100644 --- a/crates/attested-tls/tests/nested_tls.rs +++ b/crates/attested-tls/tests/nested_tls.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use attestation::{AttestationGenerator, AttestationType, AttestationVerifier}; use attested_tls::{AttestedCertificateResolver, AttestedCertificateVerifier}; +use mock_tdx::mock_pcs::{MockPcsConfig, spawn_mock_pcs_server}; use nested_tls::{client::NestingTlsConnector, server::NestingTlsAcceptor}; use ra_tls::rcgen::{KeyPair, PKCS_ECDSA_P256_SHA256}; use rustls::{ @@ -19,7 +20,7 @@ async fn nested_tls_uses_attested_tls_for_inner_session() { let provider: Arc = aws_lc_rs::default_provider().into(); let (outer_server, outer_client) = plain_tls_config_pair(provider.clone()); let inner_server = attested_server_config("localhost", provider.clone()); - let inner_client = attested_client_config(provider.clone()); + let inner_client = attested_client_config(provider.clone()).await; let acceptor = NestingTlsAcceptor::new(Arc::new(outer_server), Arc::new(inner_server)); let connector = NestingTlsConnector::new(Arc::new(outer_client), Arc::new(inner_client)); @@ -102,8 +103,13 @@ fn attested_server_config(server_name: &str, provider: Arc) -> S } /// Create client TLS config with attestation verification -fn attested_client_config(provider: Arc) -> ClientConfig { - let verifier = AttestedCertificateVerifier::build(AttestationVerifier::mock()) +async fn attested_client_config(provider: Arc) -> ClientConfig { + let mock_pcs_server = spawn_mock_pcs_server(MockPcsConfig::default()).await.unwrap(); + let verifier = AttestationVerifier::mock_with_pccs(mock_pcs_server.base_url.clone()); + if let Some(ref pccs) = verifier.internal_pccs { + pccs.ready().await.unwrap(); + } + let verifier = AttestedCertificateVerifier::build(verifier) .with_crypto_provider(provider.clone()) .finish() .unwrap(); diff --git a/crates/mock-tdx/src/lib.rs b/crates/mock-tdx/src/lib.rs index 4ea4fbf..50a2989 100644 --- a/crates/mock-tdx/src/lib.rs +++ b/crates/mock-tdx/src/lib.rs @@ -1,4 +1,4 @@ -mod mock_pcs; +pub mod mock_pcs; use dcap_qvl::{ QuoteCollateralV3, diff --git a/crates/mock-tdx/src/mock_pcs.rs b/crates/mock-tdx/src/mock_pcs.rs index 77276c6..c9fba83 100644 --- a/crates/mock-tdx/src/mock_pcs.rs +++ b/crates/mock-tdx/src/mock_pcs.rs @@ -43,7 +43,7 @@ impl Default for MockPcsConfig { let qe_identity: Value = serde_json::from_str(&collateral.qe_identity).unwrap(); Self { - include_fmspcs_listing: false, + include_fmspcs_listing: true, tcb_next_update: tcb_info["nextUpdate"].as_str().unwrap().to_string(), qe_next_update: qe_identity["nextUpdate"].as_str().unwrap().to_string(), refreshed_tcb_next_update: None,