From 98efe3ae5652247c5b1b296dcc189d298cd96ac4 Mon Sep 17 00:00:00 2001 From: Aravinth Manivannan Date: Mon, 16 Oct 2023 18:29:23 +0530 Subject: [PATCH] feat: manually read configuration from environment variables --- config/default.toml | 3 +- src/settings.rs | 358 +++++++++++++++++++++++++++++++++----------- src/tests/mod.rs | 4 +- 3 files changed, 271 insertions(+), 94 deletions(-) diff --git a/config/default.toml b/config/default.toml index 43583e64..fb06e174 100644 --- a/config/default.toml +++ b/config/default.toml @@ -45,9 +45,8 @@ duration = 30 # cooldown period in seconds # url = "postgres://batman:password@batcave.org:5432/batcave" # database_type = "postgres" # pool = 4 -url = "http://example.org" # hack for tests to run successfully +url = "postgres://example.org" # hack for tests to run successfully pool = 4 -database_type="postgres" # "postgres", "maria" [redis] # This section deals with the database location and how to access it diff --git a/src/settings.rs b/src/settings.rs index 5eaa66d2..b70aec1a 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -6,23 +6,25 @@ use std::path::Path; use std::{env, fs}; -use config::{Config, ConfigError, Environment, File, ConfigBuilder}; +use config::builder::DefaultState; +use config::{Config, ConfigBuilder, ConfigError, File}; use derive_more::Display; use serde::{Deserialize, Serialize}; use url::Url; -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Eq, PartialEq)] pub struct Server { pub port: u32, pub domain: String, pub cookie_secret: String, pub ip: String, + // TODO: remove pub url_prefix: Option, pub proxy_has_tls: bool, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Eq, PartialEq)] pub struct Captcha { pub salt: String, pub gc: u64, @@ -32,7 +34,7 @@ pub struct Captcha { pub default_difficulty_strategy: DefaultDifficultyStrategy, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Eq, PartialEq)] pub struct DefaultDifficultyStrategy { pub avg_traffic_difficulty: u32, pub broke_my_site_traffic_difficulty: u32, @@ -40,7 +42,7 @@ pub struct DefaultDifficultyStrategy { pub duration: u32, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Eq, PartialEq)] pub struct Smtp { pub from: String, pub reply: String, @@ -57,7 +59,7 @@ impl Server { } } -#[derive(Deserialize, Serialize, Display, PartialEq, Clone, Debug)] +#[derive(Deserialize, Serialize, Display, Eq, PartialEq, Clone, Debug)] #[serde(rename_all = "lowercase")] pub enum DBType { #[display(fmt = "postgres")] @@ -76,33 +78,84 @@ impl DBType { } } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Eq, PartialEq)] pub struct Database { pub url: String, pub pool: u32, pub database_type: DBType, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Eq, PartialEq)] pub struct Redis { pub url: String, pub pool: u32, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Eq, PartialEq)] pub struct Settings { pub debug: bool, pub commercial: bool, + pub source_code: String, + pub allow_registration: bool, + pub allow_demo: bool, pub database: Database, pub redis: Option, pub server: Server, pub captcha: Captcha, - pub source_code: String, pub smtp: Option, - pub allow_registration: bool, - pub allow_demo: bool, } +const ENV_VAR_CONFIG: [(&str, &str); 29] = [ + /* top-level */ + ("debug", "MCAPTCHA_debug"), + ("commercial", "MCAPTCHA_commercial"), + ("source_code", "MCAPTCHA_source_code"), + ("allow_registration", "MCAPTCHA_allow_registration"), + ("allow_demo", "MCAPTCHA_allow_demo"), + + /* database */ + ("database.url", "DATABASE_URL"), + ("database.pool", "MCAPTCHA_database_POOL"), + + /* redis */ + ("redis.url", "MCPATCHA_redis_URL"), + ("redis.pool", "MCPATCHA_redis_POOL"), + + /* server */ + ("server.port", "PORT"), + ("server.domain", "MCAPTCHA_server_DOMAIN"), + ("server.cookie_secret", "MCAPTCHA__server_COOKIE_SECRET"), + ("server.ip", "MCAPTCHA__server_IP"), + ("server.proxy_has_tls", "MCAPTCHA__server_PROXY_HAS_TLS"), + + + /* captcha */ + ("captcha.salt", "MCAPTCHA_captcha_SALT"), + ("captcha.gc", "MCAPTCHA_captcha_GC"), + ("captcha.runners", "MCAPTCHA_captcha_RUNNERS"), + ("captcha.queue_length", "MCAPTCHA_captcha_QUEUE_LENGTH"), + ("captcha.enable_stats", "MCAPTCHA_captcha_ENABLE_STATS"), + ("captcha.default_difficulty_strategy.avg_traffic_difficulty", "MCAPTCHA_captcha_DEFAULT_DIFFICULTY_STRATEGY_avg_traffic_difficulty"), + ("captcha.default_difficulty_strategy.broke_my_site_traffic_difficulty", "MCAPTCHA_captcha_DEFAULT_DIFFICULTY_STRATEGY_broke_my_site_traffic_difficulty"), + ("captcha.default_difficulty_strategy.peak_sustainable_traffic_difficulty", + "MCAPTCHA_captcha_DEFAULT_DIFFICULTY_STRATEGY_peak_sustainable_traffic_difficulty"), + ( "captcha.default_difficulty_strategy.duration", + "MCAPTCHA_captcha_DEFAULT_DIFFICULTY_STRATEGY_duration" + ), + + + /* SMTP */ + ("smtp.from", "MCPATCHA_smtp_FROM"), + ("smtp.reply", "MCPATCHA_smtp_REPLY"), + ("smtp.url", "MCPATCHA_smtp_URL"), + ("smtp.username", "MCPATCHA_smtp_USERNAME"), + ("smtp.password", "MCPATCHA_smtp_PASSWORD"), + ("smtp.port", "MCPATCHA_smtp_PORT"), + + + +]; + #[cfg(not(tarpaulin_include))] impl Settings { pub fn new() -> Result { @@ -111,9 +164,19 @@ impl Settings { const CURRENT_DIR: &str = "./config/default.toml"; const ETC: &str = "/etc/mcaptcha/config.toml"; - s = s.set_default("capatcha.enable_stats", true.to_string()) + s = s + .set_default("capatcha.enable_stats", true.to_string()) .expect("unable to set capatcha.enable_stats default config"); + // Will be overridden after config is parsed and loaded into Settings by + // Settings::set_database_type. + // This parameter is not ergonomic for users, but it is required and can be programatically + // inferred. But we need a default value for config lib to parse successfully, since it is + // DBType and not Option + s = s + .set_default("database.database_type", DBType::Postgres.to_string()) + .expect("unable to set database.database_type default config"); + if let Ok(path) = env::var("MCAPTCHA_CONFIG") { let absolute_path = Path::new(&path).canonicalize().unwrap(); log::info!( @@ -136,93 +199,208 @@ impl Settings { log::warn!("Configuration file not found"); } - s = s.add_source(Environment::with_prefix("MCAPTCHA").separator("_")); - - - if let Ok(val) = env::var("PORT") { - s = s.set_override("server.port", val).unwrap(); - log::info!("Overriding [server].port with environment variable"); - } - - match env::var("DATABASE_URL") { - Ok(val) => { - let url = Url::parse(&val).expect("couldn't parse Database URL"); - s = s.set_override("database.url", url.to_string()).unwrap(); - let database_type = DBType::from_url(&url).unwrap(); - s = s.set_override("database.database_type", database_type.to_string()) - .unwrap(); - log::info!("Overriding [database].url and [database].database_type with environment variable"); - } - Err(e) => { - log::warn!("Couldn't interpret DATABASE_URL: {:?}", e); - } - } + s = Self::env_override(s); let mut settings = s.build()?.try_deserialize::()?; settings.check_url(); - #[cfg(test)] - settings.set_test_defaults(); + settings.set_database_type(); Ok(settings) - } - // bypass for issue #15701 for more information - fn set_test_defaults(&mut self) { - self.database.pool = 2; + fn env_override(mut s: ConfigBuilder) -> ConfigBuilder { + for (parameter, env_var_name) in ENV_VAR_CONFIG.iter() { + if let Ok(val) = env::var(env_var_name) { + log::debug!( + "Overriding [{parameter}] with environment variable {env_var_name}" + ); + s = s.set_override(parameter, val).unwrap(); + } + } + + s } -#[cfg(not(tarpaulin_include))] -fn check_url(&self) { - Url::parse(&self.source_code).expect("Please enter a URL for source_code in settings"); -} + fn set_database_type(&mut self) { + let url = Url::parse(&self.database.url) + .expect("couldn't parse Database URL and detect database type"); + self.database.database_type = DBType::from_url(&url).unwrap(); + } + + fn check_url(&self) { + Url::parse(&self.source_code) + .expect("Please enter a URL for source_code in settings"); + } } -#[cfg(not(tarpaulin_include))] -fn set_database_url(s: &mut Config) { - s.set( - "database.url", - format!( - r"postgres://{}:{}@{}:{}/{}", - s.get::("database.username") - .expect("Couldn't access database username"), - urlencoding::encode( - s.get::("database.password") - .expect("Couldn't access database password") - .as_str() - ), - s.get::("database.hostname") - .expect("Couldn't access database hostname"), - s.get::("database.port") - .expect("Couldn't access database port"), - s.get::("database.name") - .expect("Couldn't access database name") - ), - ) - .expect("Couldn't set database url"); -} +#[cfg(test)] +mod tests { -//#[cfg(test)] -//mod tests { -// use super::*; -// -// #[test] -// fn url_prefix_test() { -// let mut settings = Settings::new().unwrap(); -// assert!(settings.server.url_prefix.is_none()); -// settings.server.url_prefix = Some("test".into()); -// settings.server.check_url_prefix(); -// settings.server.url_prefix = Some(" ".into()); -// settings.server.check_url_prefix(); -// assert!(settings.server.url_prefix.is_none()); -// } -// -// #[test] -// fn smtp_config_works() { -// let settings = Settings::new().unwrap(); -// assert!(settings.smtp.is_some()); -// assert_eq!(settings.smtp.as_ref().unwrap().password, "password"); -// assert_eq!(settings.smtp.as_ref().unwrap().username, "admin"); -// } -//} + use super::*; + + #[test] + fn env_override_works() { + use crate::tests::get_settings; + let init_settings = get_settings(); + // so that it can be tested outside the macro (helper) too + let mut new_settings; + + macro_rules! helper { + + + + + ($env:expr, $val:expr, $val_typed:expr, $($param:ident).+) => { + println!("Setting env var {} to {} for test", $env, $val); + env::set_var($env, $val); + new_settings = get_settings(); + assert_eq!(new_settings.$($param).+, $val_typed); + assert_ne!(new_settings.$($param).+, init_settings.$($param).+); + env::remove_var($env); + }; + + + ($env:expr, $val:expr, $($param:ident).+) => { + helper!($env, $val.to_string(), $val, $($param).+); + }; + } + + /* top level */ + helper!("MCAPTCHA_debug", false, debug); + helper!("MCAPTCHA_commercial", true, commercial); + helper!("MCAPTCHA_allow_registration", false, allow_registration); + helper!("MCAPTCHA_allow_demo", false, allow_demo); + + /* database_type */ + + helper!( + "DATABASE_URL", + "postgres://postgres:password@localhost:5432/postgres", + database.url + ); + assert_eq!(new_settings.database.database_type, DBType::Postgres); + helper!( + "DATABASE_URL", + "mysql://maria:password@localhost/maria", + database.url + ); + assert_eq!(new_settings.database.database_type, DBType::Maria); + helper!("MCAPTCHA_database_POOL", 1000, database.pool); + + /* redis */ + + /* redis.url */ + let env = "MCPATCHA_redis_URL"; + let val = "redis://redis.example.org"; + println!("Setting env var {} to {} for test", env, val); + env::set_var(env, val.to_string()); + new_settings = get_settings(); + assert_eq!(new_settings.redis.as_ref().unwrap().url, val); + assert_ne!( + new_settings.redis.as_ref().unwrap().url, + init_settings.redis.as_ref().unwrap().url + ); + env::remove_var(env); + + /* redis.pool */ + let env = "MCPATCHA_redis_POOL"; + let val = 999; + println!("Setting env var {} to {} for test", env, val); + env::set_var(env, val.to_string()); + new_settings = get_settings(); + assert_eq!(new_settings.redis.as_ref().unwrap().pool, val); + assert_ne!( + new_settings.redis.as_ref().unwrap().pool, + init_settings.redis.as_ref().unwrap().pool + ); + env::remove_var(env); + + helper!("PORT", 0, server.port); + helper!("MCAPTCHA_server_DOMAIN", "example.org", server.domain); + helper!( + "MCAPTCHA__server_COOKIE_SECRET", + "dafasdfsdf", + server.cookie_secret + ); + helper!("MCAPTCHA__server_IP", "9.9.9.9", server.ip); + helper!("MCAPTCHA__server_PROXY_HAS_TLS", true, server.proxy_has_tls); + + /* captcha */ + + helper!("MCAPTCHA_captcha_SALT", "foobarasdfasdf", captcha.salt); + helper!("MCAPTCHA_captcha_GC", 500, captcha.gc); + helper!( + "MCAPTCHA_captcha_RUNNERS", + "500", + Some(500), + captcha.runners + ); + + helper!("MCAPTCHA_captcha_QUEUE_LENGTH", 500, captcha.queue_length); + helper!("MCAPTCHA_captcha_ENABLE_STATS", false, captcha.enable_stats); + helper!( + "MCAPTCHA_captcha_DEFAULT_DIFFICULTY_STRATEGY_avg_traffic_difficulty", + 999, + captcha.default_difficulty_strategy.avg_traffic_difficulty + ); + helper!("MCAPTCHA_captcha_DEFAULT_DIFFICULTY_STRATEGY_peak_sustainable_traffic_difficulty", 999 , captcha.default_difficulty_strategy.peak_sustainable_traffic_difficulty); + helper!("MCAPTCHA_captcha_DEFAULT_DIFFICULTY_STRATEGY_broke_my_site_traffic_difficulty", 999 , captcha.default_difficulty_strategy.broke_my_site_traffic_difficulty); + helper!( + "MCAPTCHA_captcha_DEFAULT_DIFFICULTY_STRATEGY_duration", + 999, + captcha.default_difficulty_strategy.duration + ); + + /* SMTP */ + + let vals = [ + "MCPATCHA_smtp_FROM", + "MCPATCHA_smtp_REPLY", + "MCPATCHA_smtp_URL", + "MCPATCHA_smtp_USERNAME", + "MCPATCHA_smtp_PASSWORD", + "MCPATCHA_smtp_PORT", + ]; + for env in vals.iter() { + println!("Setting env var {} to {} for test", env, env); + env::set_var(env, env.to_string()); + } + + let port = 9999; + env::set_var("MCPATCHA_smtp_PORT", port.to_string()); + + new_settings = get_settings(); + let smtp_new = new_settings.smtp.as_ref().unwrap(); + let smtp_old = init_settings.smtp.as_ref().unwrap(); + assert_eq!(smtp_new.from, "MCPATCHA_smtp_FROM"); + assert_eq!(smtp_new.reply, "MCPATCHA_smtp_REPLY"); + assert_eq!(smtp_new.username, "MCPATCHA_smtp_USERNAME"); + assert_eq!(smtp_new.password, "MCPATCHA_smtp_PASSWORD"); + assert_eq!(smtp_new.port, port); + assert_ne!(smtp_new, smtp_old); + + for env in vals.iter() { + env::remove_var(env); + } + } + + // #[test] + // fn url_prefix_test() { + // let mut settings = Settings::new().unwrap(); + // assert!(settings.server.url_prefix.is_none()); + // settings.server.url_prefix = Some("test".into()); + // settings.server.check_url_prefix(); + // settings.server.url_prefix = Some(" ".into()); + // settings.server.check_url_prefix(); + // assert!(settings.server.url_prefix.is_none()); + // } + // + // #[test] + // fn smtp_config_works() { + // let settings = Settings::new().unwrap(); + // assert!(settings.smtp.is_some()); + // assert_eq!(settings.smtp.as_ref().unwrap().password, "password"); + // assert_eq!(settings.smtp.as_ref().unwrap().username, "admin"); + // } +} diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 5234be7a..224200a9 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -41,7 +41,7 @@ pub mod pg { settings.database.url = url.clone(); settings.database.database_type = DBType::Postgres; settings.database.pool = 2; - + Data::new(&settings).await } } @@ -61,7 +61,7 @@ pub mod maria { settings.database.url = url.clone(); settings.database.database_type = DBType::Maria; settings.database.pool = 2; - + Data::new(&settings).await } }