diff --git a/Cargo.lock b/Cargo.lock index ac8f39b6..6d80e62c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -493,7 +493,7 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "burrego" version = "0.3.4" -source = "git+https://github.com/kubewarden/policy-evaluator?tag=v0.18.1#faea40d47f9fc91728663715aac1fb3131b53f30" +source = "git+https://github.com/kubewarden/policy-evaluator?tag=v0.18.2#e363148a59cf3de6ae29b55d2a845ac3ed6f4e0e" dependencies = [ "base64 0.22.1", "chrono", @@ -1540,6 +1540,21 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -2202,6 +2217,28 @@ dependencies = [ "serde", ] +[[package]] +name = "inotify" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdd168d97690d0b8c412d6b6c10360277f4d7ee495c5d0d5d5fe0854923255cc" +dependencies = [ + "bitflags 1.3.2", + "futures-core", + "inotify-sys", + "libc", + "tokio", +] + +[[package]] +name = "inotify-sys" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb" +dependencies = [ + "libc", +] + [[package]] name = "inout" version = "0.1.3" @@ -3172,12 +3209,50 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" +[[package]] +name = "openssl" +version = "0.10.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" +dependencies = [ + "bitflags 2.6.0", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.71", +] + [[package]] name = "openssl-probe" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +[[package]] +name = "openssl-sys" +version = "0.9.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c597637d56fbc83893a35eb0dd04b2b8e7a50c91e64e9493e398b5df4fb45fa2" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "opentelemetry" version = "0.23.0" @@ -3649,8 +3724,8 @@ checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" [[package]] name = "policy-evaluator" -version = "0.18.1" -source = "git+https://github.com/kubewarden/policy-evaluator?tag=v0.18.1#faea40d47f9fc91728663715aac1fb3131b53f30" +version = "0.18.2" +source = "git+https://github.com/kubewarden/policy-evaluator?tag=v0.18.2#e363148a59cf3de6ae29b55d2a845ac3ed6f4e0e" dependencies = [ "anyhow", "base64 0.22.1", @@ -3682,7 +3757,7 @@ dependencies = [ "validator", "wapc", "wasi-common", - "wasmparser 0.213.0", + "wasmparser 0.214.0", "wasmtime", "wasmtime-provider", "wasmtime-wasi", @@ -3690,8 +3765,8 @@ dependencies = [ [[package]] name = "policy-fetcher" -version = "0.8.7" -source = 
"git+https://github.com/kubewarden/policy-fetcher?tag=v0.8.7#fa98db1aad51fa3181d5259f5c251702eb18963c" +version = "0.8.8" +source = "git+https://github.com/kubewarden/policy-fetcher?tag=v0.8.8#0f31d41442390c87d55b4cb24d6249ae962d3110" dependencies = [ "async-trait", "base64 0.22.1", @@ -3730,6 +3805,7 @@ dependencies = [ "daemonize", "futures", "http-body-util", + "inotify", "itertools 0.13.0", "jemalloc_pprof", "k8s-openapi", @@ -3738,13 +3814,16 @@ dependencies = [ "mockall", "mockall_double", "num_cpus", + "openssl", "opentelemetry", "opentelemetry-otlp", "opentelemetry_sdk", "policy-evaluator", "pprof", "rayon", + "rcgen", "regex", + "reqwest", "rhai", "rstest", "rustls-pki-types", @@ -3758,6 +3837,7 @@ dependencies = [ "tikv-jemalloc-ctl", "tikv-jemallocator", "tokio", + "tokio-stream", "tower", "tower-http", "tracing", @@ -4150,6 +4230,19 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "rcgen" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54077e1872c46788540de1ea3d7f4ccb1983d12f9aa909b234468676c1a36779" +dependencies = [ + "pem", + "ring", + "rustls-pki-types", + "time", + "yasna", +] + [[package]] name = "redox_syscall" version = "0.5.3" @@ -4241,8 +4334,10 @@ checksum = "c7d6d2a27d57148378eb5e111173f4276ad26340ecc5c49a4a2152167a2d6a37" dependencies = [ "base64 0.22.1", "bytes", + "encoding_rs", "futures-core", "futures-util", + "h2 0.4.5", "http 1.1.0", "http-body 1.0.1", "http-body-util", @@ -5772,6 +5867,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.4" @@ -5963,9 +6064,9 @@ dependencies = [ [[package]] name = "wasmparser" 
-version = "0.213.0" +version = "0.214.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e48e5a90a9e0afc2990437f5600b8de682a32b18cbaaf6f2b5db185352868b6b" +checksum = "5309c1090e3e84dad0d382f42064e9933fdaedb87e468cc239f0eabea73ddcb6" dependencies = [ "ahash 0.8.11", "bitflags 2.6.0", @@ -6781,6 +6882,15 @@ dependencies = [ "tls_codec", ] +[[package]] +name = "yasna" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" +dependencies = [ + "time", +] + [[package]] name = "zerocopy" version = "0.7.35" diff --git a/Cargo.toml b/Cargo.toml index 5835a14e..8831b40e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ opentelemetry = { version = "0.23.0", default-features = false, features = [ ] } opentelemetry_sdk = { version = "0.23.0", features = ["rt-tokio"] } pprof = { version = "0.13", features = ["prost-codec"] } -policy-evaluator = { git = "https://github.com/kubewarden/policy-evaluator", tag = "v0.18.1" } +policy-evaluator = { git = "https://github.com/kubewarden/policy-evaluator", tag = "v0.18.2" } rustls-pki-types = { version = "1", features = ["alloc"] } rayon = "1.10" regex = "1.10" @@ -55,9 +55,22 @@ jemalloc_pprof = "0.4.1" tikv-jemalloc-ctl = "0.5.4" rhai = { version = "1.19.0", features = ["sync"] } +[target.'cfg(target_os = "linux")'.dependencies] +inotify = "0.10" +tokio-stream = "0.1.15" + [dev-dependencies] mockall = "0.12" rstest = "0.21" tempfile = "3.10.1" tower = { version = "0.4", features = ["util"] } http-body-util = "0.1.1" + +[target.'cfg(target_os = "linux")'.dev-dependencies] +rcgen = { version = "0.13", features = ["crypto"] } +openssl = "0.10" +reqwest = { version = "0.12", default-features = false, features = [ + "charset", + "http2", + "rustls-tls-manual-roots", +] } diff --git a/src/lib.rs b/src/lib.rs index 619c0418..b777499d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -39,6 +39,10 @@ use 
tokio::{
 };
 use tower_http::trace::{self, TraceLayer};
 
+// This is required by certificate hot reload when using inotify, which is available only on linux
+#[cfg(target_os = "linux")]
+use tokio_stream::StreamExt;
+
 use crate::api::handlers::{
     audit_handler, pprof_get_cpu, pprof_get_heap, readiness_handler, validate_handler,
     validate_raw_handler,
@@ -46,7 +50,7 @@ use crate::api::handlers::{
 use crate::api::state::ApiServerState;
 use crate::evaluation::precompiled_policy::{PrecompiledPolicies, PrecompiledPolicy};
 use crate::policy_downloader::{Downloader, FetchedPolicies};
-use config::Config;
+use config::{Config, TlsConfig};
 
 use tikv_jemallocator::Jemalloc;
 
@@ -193,9 +197,7 @@ impl PolicyServer {
         });
 
         let tls_config = if let Some(tls_config) = config.tls_config {
-            let rustls_config =
-                RustlsConfig::from_pem_file(tls_config.cert_file, tls_config.key_file).await?;
-            Some(rustls_config)
+            Some(create_tls_config_and_watch_certificate_changes(tls_config).await?)
         } else {
             None
         };
@@ -269,6 +271,103 @@ impl PolicyServer {
     }
 }
 
+/// Build the `RustlsConfig` from the configured certificate and key files.
+///
+/// There's no watching of the certificate files on non-linux platforms
+/// since we rely on inotify to watch for changes
+#[cfg(not(target_os = "linux"))]
+async fn create_tls_config_and_watch_certificate_changes(
+    tls_config: TlsConfig,
+) -> Result<RustlsConfig> {
+    let cfg = RustlsConfig::from_pem_file(tls_config.cert_file, tls_config.key_file).await?;
+    Ok(cfg)
+}
+
+/// Return the RustlsConfig and watch for changes in the certificate files
+/// using inotify.
+/// When both the certificate and its key are changed, the RustlsConfig is reloaded,
+/// causing the https server to use the new certificate.
+///
+/// Relying on inotify is only available on linux
+///
+/// # Errors
+///
+/// Fails when the PEM files cannot be loaded, when inotify cannot be initialized,
+/// when one of the two files cannot be watched or the event stream cannot be created.
+#[cfg(target_os = "linux")]
+async fn create_tls_config_and_watch_certificate_changes(
+    tls_config: TlsConfig,
+) -> Result<RustlsConfig> {
+    let cert_file = tls_config.cert_file.clone();
+    let key_file = tls_config.key_file.clone();
+
+    let rust_config =
+        RustlsConfig::from_pem_file(tls_config.cert_file, tls_config.key_file).await?;
+    let reloadable_rust_config = rust_config.clone();
+
+    let inotify =
+        inotify::Inotify::init().map_err(|e| anyhow!("Cannot initialize inotify: {e}"))?;
+    // Watch CLOSE_WRITE rather than MODIFY: MODIFY is raised on truncation and on every
+    // write, so with MODIFY a reload could start while a file is still half-written.
+    let cert_watch = inotify
+        .watches()
+        .add(&cert_file, inotify::WatchMask::CLOSE_WRITE)
+        .map_err(|e| anyhow!("Cannot watch certificate file: {e}"))?;
+    let key_watch = inotify
+        .watches()
+        .add(&key_file, inotify::WatchMask::CLOSE_WRITE)
+        .map_err(|e| anyhow!("Cannot watch key file: {e}"))?;
+
+    let buffer = [0; 1024];
+    let stream = inotify
+        .into_event_stream(buffer)
+        .map_err(|e| anyhow!("Cannot create inotify event stream: {e}"))?;
+
+    tokio::spawn(async move {
+        tokio::pin!(stream);
+        let mut cert_changed = false;
+        let mut key_changed = false;
+
+        while let Some(event) = stream.next().await {
+            let event = match event {
+                Ok(event) => event,
+                Err(e) => {
+                    warn!("Cannot read inotify event: {e}");
+                    continue;
+                }
+            };
+
+            if event.wd == cert_watch {
+                info!("TLS certificate file has been modified");
+                cert_changed = true;
+            }
+            if event.wd == key_watch {
+                info!("TLS key file has been modified");
+                key_changed = true;
+            }
+
+            // Reload only once both files have changed, so that a new certificate is
+            // not loaded together with the old key (or the other way around).
+            if key_changed && cert_changed {
+                info!("reloading TLS certificate");
+
+                cert_changed = false;
+                key_changed = false;
+                // NOTE(review): this panic unwinds only this spawned watcher task;
+                // unless the binary is built with panic=abort the server presumably
+                // keeps running with the old certificate and stops watching — confirm
+                // this is the intended failure mode.
+                reloadable_rust_config
+                    .reload_from_pem_file(cert_file.clone(), key_file.clone())
+                    .await
+                    .expect("Cannot reload TLS certificate"); // we want to panic here
+            }
+        }
+    });
+
+    Ok(rust_config)
+}
+
 fn precompile_policies(
     engine: &wasmtime::Engine,
     fetched_policies: &FetchedPolicies,
diff --git a/tests/integration_test.rs b/tests/integration_test.rs
index ceb2c0b0..c85af277
100644
--- a/tests/integration_test.rs
+++ b/tests/integration_test.rs
@@ -519,3 +519,177 @@ async fn test_policy_with_wrong_url() {
     assert_eq!(status.code, Some(500));
     assert!(pattern.is_match(&status.message.unwrap()));
 }
+
+// helper functions for certificate rotation test, which is a feature supported only on Linux
+#[cfg(target_os = "linux")]
+mod certificate_reload_helpers {
+    use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
+    use rcgen::{generate_simple_self_signed, CertifiedKey};
+    use std::net::TcpStream;
+
+    /// PEM-encoded private key and certificate, as produced by `create_cert`.
+    pub struct TlsData {
+        pub key: String,
+        pub cert: String,
+    }
+
+    /// Generate a self-signed certificate with `hostname` as subject alternative name.
+    pub fn create_cert(hostname: &str) -> TlsData {
+        let subject_alt_names = vec![hostname.to_string()];
+
+        let CertifiedKey { cert, key_pair } =
+            generate_simple_self_signed(subject_alt_names).unwrap();
+
+        TlsData {
+            key: key_pair.serialize_pem(),
+            cert: cert.pem(),
+        }
+    }
+
+    /// Connect to `domain_ip:domain_port` over TLS, without verifying the peer, and
+    /// return the DNS names found in the SAN entries of the certificate being served.
+    ///
+    /// Panics when the connection or the handshake fails, when the certificate has no
+    /// SAN entries, or when one of them is not a DNS name.
+    pub async fn get_tls_san_names(domain_ip: &str, domain_port: &str) -> Vec<String> {
+        let domain_ip = domain_ip.to_string();
+        let domain_port = domain_port.to_string();
+
+        // the openssl connector and `std::net::TcpStream` are blocking APIs
+        tokio::task::spawn_blocking(move || {
+            let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
+            builder.set_verify(SslVerifyMode::NONE);
+            let connector = builder.build();
+            let stream = TcpStream::connect(format!("{domain_ip}:{domain_port}")).unwrap();
+            let stream = connector.connect(&domain_ip, stream).unwrap();
+
+            let cert = stream.ssl().peer_certificate().unwrap();
+            cert.subject_alt_names()
+                .expect("failed to get SAN names")
+                .iter()
+                .map(|name| {
+                    name.dnsname()
+                        .expect("failed to get DNS name from SAN entry")
+                        .to_string()
+                })
+                .collect::<Vec<String>>()
+        })
+        .await
+        .unwrap()
+    }
+
+    /// Poll the server until the certificate it serves lists `hostname` among its SAN
+    /// names. Gives up and returns `false` after 10 attempts, one second apart.
+    pub async fn check_tls_san_name(domain_ip: &str, domain_port: &str, hostname: &str) -> bool {
+        let sleep_interval = std::time::Duration::from_secs(1);
+        let max_retries = 10;
+        let mut failed_retries = 0;
+        let hostname = hostname.to_string();
+        loop {
+            let san_names = get_tls_san_names(domain_ip, domain_port).await;
+            if san_names.contains(&hostname) {
+                return true;
+            }
+            failed_retries += 1;
+            if failed_retries >= max_retries {
+                return false;
+            }
+            tokio::time::sleep(sleep_interval).await;
+        }
+    }
+
+    /// Poll `https://{address}/readiness` until a request gets a response; the status
+    /// code is not checked. Panics after 5 failed attempts, one second apart.
+    pub async fn wait_for_policy_server_to_be_ready(address: &str) {
+        let sleep_interval = std::time::Duration::from_secs(1);
+        let max_retries = 5;
+        let mut failed_retries = 0;
+
+        // wait for the server to start
+        let client = reqwest::Client::builder()
+            .danger_accept_invalid_certs(true)
+            .build()
+            .unwrap();
+
+        loop {
+            let url = reqwest::Url::parse(&format!("https://{address}/readiness")).unwrap();
+            match client.get(url).send().await {
+                Ok(_) => break,
+                Err(e) => {
+                    failed_retries += 1;
+                    if failed_retries >= max_retries {
+                        panic!("failed to start the server: {:?}", e);
+                    }
+                    tokio::time::sleep(sleep_interval).await;
+                }
+            }
+        }
+    }
+}
+
+/// The served certificate must change only after both the certificate file and the
+/// key file have been rewritten.
+#[cfg(target_os = "linux")]
+#[tokio::test(flavor = "multi_thread")]
+async fn test_detect_certificate_rotation() {
+    use certificate_reload_helpers::*;
+
+    let certs_dir = tempfile::tempdir().unwrap();
+    let cert_file = certs_dir.path().join("policy-server.pem");
+    let key_file = certs_dir.path().join("policy-server-key.pem");
+
+    let hostname1 = "cert1.example.com";
+    let tls_data1 = create_cert(hostname1);
+
+    std::fs::write(&cert_file, tls_data1.cert).unwrap();
+    std::fs::write(&key_file, tls_data1.key).unwrap();
+
+    let mut config = default_test_config();
+    config.tls_config = Some(policy_server::config::TlsConfig {
+        cert_file: cert_file.to_str().unwrap().to_string(),
+        key_file: key_file.to_str().unwrap().to_string(),
+    });
+    config.policies = HashMap::new();
+
+    let domain_ip = config.addr.ip().to_string();
+    let domain_port = config.addr.port().to_string();
+
+    tokio::spawn(async move {
+        policy_server::tracing::setup_tracing(
+            &config.log_level,
+            &config.log_fmt,
+            config.log_no_color,
+        )
+        .unwrap();
+        let api_server = policy_server::PolicyServer::new_from_config(config)
+            .await
+            .unwrap();
+        api_server.run().await.unwrap();
+    });
+    wait_for_policy_server_to_be_ready(format!("{domain_ip}:{domain_port}").as_str()).await;
+
+    assert!(check_tls_san_name(&domain_ip, &domain_port, hostname1).await);
+
+    // Generate a new certificate and key, and switch to them
+
+    let hostname2 = "cert2.example.com";
+    let tls_data2 = create_cert(hostname2);
+
+    // write only the cert file
+    std::fs::write(&cert_file, tls_data2.cert).unwrap();
+
+    // give inotify some time to ensure it detected the cert change
+    tokio::time::sleep(std::time::Duration::from_secs(4)).await;
+
+    // the old certificate should still be in use, since we didn't change also the key
+    assert!(check_tls_san_name(&domain_ip, &domain_port, hostname1).await);
+
+    // now write the key file too
+    std::fs::write(&key_file, tls_data2.key).unwrap();
+
+    // give inotify some time to ensure it detected the key change
+    tokio::time::sleep(std::time::Duration::from_secs(4)).await;
+
+    // the new certificate should be in use
+    assert!(check_tls_san_name(&domain_ip, &domain_port, hostname2).await);
+}