diff --git a/include/Proxy.h b/include/Proxy.h index 72c0033..ee5d76a 100644 --- a/include/Proxy.h +++ b/include/Proxy.h @@ -3,19 +3,21 @@ #include #include #include +#include "Socket.h" #define SMTP_STATE_WAIT_FOR_HELO 0 #define SMTP_STATE_WAIT_FOR_MAILFROM 1 #define SMTP_STATE_WAIT_FOR_RCPTTO 2 #define SMTP_STATE_WAIT_FOR_DATA 3 -class Proxy -{ +class Proxy { public: Proxy(); - void run(boost::asio::ssl::stream* outside, const std::string& peer_address); + void setOutside(Socket& socket); + void run(const std::string& peer_address); private: - boost::asio::io_service& io_service_; + boost::asio::io_context io_context_; boost::asio::ssl::context ssl_context; + Socket* outside_socket_; }; diff --git a/include/Utils.h b/include/Utils.h index ad05f1f..9a3fa9f 100644 --- a/include/Utils.h +++ b/include/Utils.h @@ -22,6 +22,7 @@ #include "hermes.h" #include +#include #include #include #include @@ -35,7 +36,9 @@ #include "Database.h" #include "Socket.h" -using namespace std; +using std::string; +using std::stringstream; +using std::list; #ifdef WIN32 #define sleep(x) Sleep(1000*(x)) @@ -52,37 +55,37 @@ class Utils { public: //string utilities - static string strtolower(string); - static string trim(string); + static string strtolower(std::string_view); + static string trim(std::string_view); static string inttostr(int); static string ulongtostr(unsigned long); //email-related utilities - static string getmail(string&); - static string getdomain(string&); - static string reverseip(string&); + static string getmail(const string&); + static string getdomain(const string&); + static string reverseip(const string&); //spam-related utilities (TODO: move to a different class) - static bool greylist(string,string&,string&,string&); - static bool listed_on_dns_lists(list&,unsigned char,string&); - static bool whitelisted(string,string&); - static bool blacklisted(string,string&,string&); + static bool greylist(const string& dbfile, string& ip, string& p_from, string& p_to); + static bool listed_on_dns_lists(const list& dns_domains, unsigned char percentage, const string& ip); + static bool whitelisted(const string& dbfile, string& ip); + static bool blacklisted(const string& dbfile, string& ip, string& to); #ifndef WIN32 //posix-utils - static int usertouid(string); - static int grouptogid(string); + static int usertouid(const string& user); + static int grouptogid(const string& groupname); #endif //WIN32 //misc - static string get_canonical_filename(string); - static bool file_exists(string); - static bool dir_exists(string); + static string get_canonical_filename(const string& file); + static bool file_exists(const string& file); + static bool dir_exists(const string& dir); static string errnotostrerror(int); static string rfc2821_date(time_t *timestamp=NULL); static string gethostname(); - static void write_pid(string,pid_t); - static string gethostname(int s); + static void write_pid(const string& file, pid_t pid); + static string gethostname(int socket); }; #endif //UTILS_H diff --git a/src/Proxy.cpp b/src/Proxy.cpp index 174b4a4..415a7a1 100644 --- a/src/Proxy.cpp +++ b/src/Proxy.cpp @@ -2,20 +2,53 @@ #include "Proxy.h" #include #include -#include +#include +#include #include +#include #include "Utils.h" #include "Configfile.h" extern Configfile cfg; -void Proxy::run(boost::asio::ssl::stream* outside, const std::string& peer_address) { - boost::asio::ssl::stream* inside; +namespace { +std::string resolveHostname(const std::string& ip_address) { + try { + boost::asio::io_context io_context; + boost::asio::ip::tcp::resolver resolver(io_context); + boost::asio::ip::address addr = boost::asio::ip::make_address(ip_address); + boost::asio::ip::tcp::endpoint ep(addr, 0); + auto results = resolver.resolve(ep.address().to_string(), ""); + if (results.begin() != results.end()) { + return results.begin()->host_name(); + } + } catch (const std::exception&) {} + return ""; +} +} + +Proxy::Proxy() + : io_context_(), + ssl_context(boost::asio::ssl::context::tlsv12), + outside_socket_(nullptr) { +} + +void Proxy::setOutside(Socket& socket) { + outside_socket_ = &socket; +} + +void Proxy::run(const std::string& peer_address) { + if (!outside_socket_) { + throw std::runtime_error("Outside socket not set"); + } + + boost::asio::ssl::stream inside(io_context_, ssl_context); // Original comments and variables retained - std::string from = ""; - std::string to = ""; - std::string ehlostr = ""; - std::string resolvedname = ""; + const std::string empty_str; + std::string from = empty_str; + std::string to = empty_str; + std::string ehlostr = empty_str; + std::string resolvedname = empty_str; unsigned char last_state = SMTP_STATE_WAIT_FOR_HELO; long unimplemented_requests = 0; @@ -25,13 +58,7 @@ void Proxy::run(boost::asio::ssl::stream* outside, try { // Resolve hostname using the Boost resolver - try { - resolvedname = HostnameResolver::resolveHostname(peer_address); - } - catch (const std::exception& e) { - std::cerr << std::format("Hostname resolution error: {}", e.what()) << std::endl; - resolvedname = ""; - } + resolvedname = resolveHostname(peer_address); // Configure SSL contexts boost::asio::ssl::context ssl_context(boost::asio::ssl::context::tlsv12); @@ -53,7 +80,8 @@ void Proxy::run(boost::asio::ssl::stream* outside, } } - if (cfg.getWhitelistedDisablesEverything() && Utils::whitelisted(cfg.getDatabaseFile(), peer_address)) { + std::string peer_addr_copy = peer_address; + if (cfg.getWhitelistedDisablesEverything() && Utils::whitelisted(cfg.getDatabaseFile(), peer_addr_copy)) { throttled = false; authenticated = true; } else { @@ -61,22 +89,20 @@ void Proxy::run(boost::asio::ssl::stream* outside, std::this_thread::sleep_for(std::chrono::seconds(cfg.getBannerDelayTime())); // Check if data is waiting before server banner - boost::system::error_code ec; - size_t available = outside->lowest_layer().available(ec); - if (ec || available > 0) { - std::cout << std::format("421 (data_before_banner) (ip:{})\n", peer_address); + if (outside_socket_->canRead(0.0)) { + std::cout << fmt::format("421 (data_before_banner) (ip:{})\n", peer_address); std::this_thread::sleep_for(std::chrono::seconds(20)); // Write rejection message - std::string rejection_msg = "421 Stop sending data before we show you the banner\r\n"; - boost::asio::write(outside->lowest_layer(), boost::asio::buffer(rejection_msg), ec); + std::string rejection_msg = fmt::format("421 Stop sending data before we show you the banner\r\n"); + outside_socket_->writeBytes(const_cast(rejection_msg.c_str()), rejection_msg.length()); return; } } } // Connect to the inside server - boost::asio::ip::tcp::resolver resolver(io_service_); + boost::asio::ip::tcp::resolver resolver(io_context_); boost::asio::ip::tcp::resolver::results_type endpoints = resolver.resolve(cfg.getServerHost(), std::to_string(cfg.getServerPort())); @@ -89,159 +115,151 @@ void Proxy::run(boost::asio::ssl::stream* outside, } if (cfg.getIncomingSsl()) { - outside->set_verify_mode(boost::asio::ssl::verify_none); - outside->handshake(boost::asio::ssl::stream_base::server); + #ifdef HAVE_SSL + outside_socket_->prepareSSL(true); + outside_socket_->startSSL(true); + #endif } // Communication buffers - std::vector read_buffer(4096); - boost::system::error_code ec; + char read_buffer[4096]; + ssize_t bytes_read; // Main loop for communication - while (!outside->lowest_layer().is_open() || !inside.lowest_layer().is_open()) { + while (!outside_socket_->isClosed() && !inside.lowest_layer().is_open()) { // Check if the client wants to send something to the server - size_t client_available = outside->lowest_layer().available(ec); - if (client_available > 0) { - size_t bytes_read = outside->read_some(boost::asio::buffer(read_buffer), ec); - strtemp = std::string(read_buffer.begin(), read_buffer.begin() + bytes_read); + if (outside_socket_->canRead(1.0)) { + bytes_read = outside_socket_->readBytes(read_buffer, sizeof(read_buffer)); + if (bytes_read > 0) { + strtemp = std::string(read_buffer, bytes_read); - if (strtemp.length() > 10 && "mail from:" == Utils::strtolower(strtemp.substr(0, 10))) { - from = Utils::getmail(strtemp); - last_state = SMTP_STATE_WAIT_FOR_RCPTTO; + std::string cmd_lower = Utils::strtolower(strtemp.substr(0, std::min(10, strtemp.length()))); + if (strtemp.length() > 10 && "mail from:" == cmd_lower) { + from = std::string(Utils::getmail(strtemp)); + last_state = SMTP_STATE_WAIT_FOR_RCPTTO; + } + + cmd_lower = Utils::strtolower(strtemp.substr(0, std::min(4, strtemp.length()))); + if ("ehlo" == cmd_lower) esmtp = true; + + if (strtemp.length() > 4 && ("ehlo" == cmd_lower || + "helo" == cmd_lower)) { + ehlostr = std::string(Utils::trim(strtemp.substr(5))); + last_state = SMTP_STATE_WAIT_FOR_MAILFROM; + } + + // RCPT TO handling with comprehensive checks + cmd_lower = Utils::strtolower(strtemp.substr(0, std::min(8, strtemp.length()))); + if (strtemp.length() > 8 && "rcpt to:" == cmd_lower) { + std::string mechanism; + std::string message; + to = std::string(Utils::getmail(strtemp)); + + // Construct log string + std::string strlog = fmt::format("from {} (ip:{}, hostname:{}, {} {}) -> to {}", + from, peer_address, resolvedname, (esmtp ? "ehlo" : "helo"), ehlostr, to); + + // Greylisting check + std::string code = "250"; + std::string peer_addr_copy = peer_address; + std::string from_copy = from; + std::string to_copy = to; + if (cfg.getGreylist() && !authenticated && + Utils::greylist(cfg.getDatabaseFile(), peer_addr_copy, from_copy, to_copy)) { + code = "421"; + mechanism = "greylist"; + message = fmt::format("{} Greylisted!! Please try again in a few minutes.", code); + std::cout << fmt::format("checking {}\n", mechanism); + } + // SPF Check + #ifdef HAVE_SPF + else if (cfg.getQuerySpf() && !authenticated && + !spf_checker.query(peer_address, ehlostr, from)) { + code = cfg.getAddStatusHeader() ? "250" : + (cfg.getReturnTempErrorOnReject() ? "421" : "550"); + mechanism = "spf"; + message = fmt::format("{} You do not seem to be allowed to send email for that particular domain.", code); + std::cout << fmt::format("checking {}\n", mechanism); + } + #endif + // Blacklist check + else if (!authenticated && + Utils::blacklisted(cfg.getDatabaseFile(), peer_addr_copy, to_copy)) { + code = cfg.getReturnTempErrorOnReject() ? "421" : "550"; + mechanism = "allowed-domain-per-ip"; + message = fmt::format("{} You do not seem to be allowed to send email to that particular domain from that address.", code); + std::cout << fmt::format("checking {}\n", mechanism); + } + // DNS Blacklist check + else if (!cfg.getDnsBlacklistDomains().empty() && !authenticated && + Utils::listed_on_dns_lists(cfg.getDnsBlacklistDomains(), + cfg.getDnsBlacklistPercentage(), + peer_address)) { + code = cfg.getAddStatusHeader() ? "250" : + (cfg.getReturnTempErrorOnReject() ? "421" : "550"); + mechanism = "dnsbl"; + message = fmt::format("{} You are listed on some DNS blacklists. Get delisted before trying to send us email.", code); + std::cout << fmt::format("checking {}\n", mechanism); + } + // Reverse DNS check + else if (cfg.getRejectNoReverseResolution() && !authenticated && resolvedname.empty()) { + code = cfg.getReturnTempErrorOnReject() ? "421" : "550"; + mechanism = "no reverse resolution"; + message = fmt::format("{} Your IP address does not resolve to a hostname.", code); + std::cout << fmt::format("checking {}\n", mechanism); + } + // HELO/Reverse name check + else if (cfg.getCheckHeloAgainstReverse() && !authenticated && ehlostr != resolvedname) { + code = cfg.getReturnTempErrorOnReject() ? "421" : "550"; + mechanism = "helo differs from resolved name"; + message = fmt::format("{} Your IP hostname doesn't match your envelope hostname.", code); + std::cout << fmt::format("checking {}\n", mechanism); + } + + // Prepare log message + std::string log_str = strlog; + if (!mechanism.empty()) { + log_str = fmt::format("({}) {}", mechanism, log_str); + } + std::cout << fmt::format("{} {}\n", code, log_str); + + // Handle rejection + if (code != "250") { + // Close inside connection + inside.lowest_layer().close(); + + // Delay to annoy spammers + std::this_thread::sleep_for(std::chrono::seconds(20)); + + // Send rejection message + std::string rejection_msg = fmt::format("{}\r\n", message); + outside_socket_->writeBytes(const_cast(rejection_msg.c_str()), rejection_msg.length()); + return; + } + + last_state = SMTP_STATE_WAIT_FOR_DATA; + } + + // Send to inside server + boost::asio::async_write(inside, boost::asio::buffer(strtemp), + [](const boost::system::error_code& /*error*/, std::size_t /*bytes_transferred*/) {}); } - - if ("ehlo" == Utils::strtolower(strtemp.substr(0, 4))) esmtp = true; - - if (strtemp.length() > 4 && ("ehlo" == Utils::strtolower(strtemp.substr(0, 4)) || - "helo" == Utils::strtolower(strtemp.substr(0, 4)))) { - ehlostr = Utils::trim(strtemp.substr(5)); - last_state = SMTP_STATE_WAIT_FOR_MAILFROM; - } - - // RCPT TO handling with comprehensive checks - if (strtemp.length() > 8 && "rcpt to:" == Utils::strtolower(strtemp.substr(0, 8))) { - std::string mechanism = ""; - std::string message = ""; - to = Utils::getmail(strtemp); - - // Construct log string using std::format - std::string strlog = std::format( - "from {} (ip:{}, hostname:{}, {} {}:{}) -> to {}", - from, - peer_address, - resolvedname, - (esmtp ? "ehlo" : "helo"), - ehlostr, - to - ); - - // Greylisting check - std::string code = "250"; - if (cfg.getGreylist() && !authenticated && - Utils::greylist(cfg.getDatabaseFile(), peer_address, from, to)) { - code = "421"; - mechanism = "greylist"; - message = std::format("{} Greylisted!! Please try again in a few minutes.", code); - std::cout << std::format("checking {}\n", mechanism); - } - // SPF Check - #ifdef HAVE_SPF - else if (cfg.getQuerySpf() && !authenticated && - !spf_checker.query(peer_address, ehlostr, from)) { - code = cfg.getAddStatusHeader() ? "250" : - (cfg.getReturnTempErrorOnReject() ? "421" : "550"); - mechanism = "spf"; - message = std::format( - "{} You do not seem to be allowed to send email for that particular domain.", - code - ); - std::cout << std::format("checking {}\n", mechanism); - } - #endif - // Blacklist check - else if (!authenticated && - Utils::blacklisted(cfg.getDatabaseFile(), peer_address, to)) { - code = cfg.getReturnTempErrorOnReject() ? "421" : "550"; - mechanism = "allowed-domain-per-ip"; - message = std::format( - "{} You do not seem to be allowed to send email to that particular domain from that address.", - code - ); - std::cout << std::format("checking {}\n", mechanism); - } - // DNS Blacklist check - else if (!cfg.getDnsBlacklistDomains().empty() && !authenticated && - Utils::listed_on_dns_lists(cfg.getDnsBlacklistDomains(), - cfg.getDnsBlacklistPercentage(), - peer_address)) { - code = cfg.getAddStatusHeader() ? "250" : - (cfg.getReturnTempErrorOnReject() ? "421" : "550"); - mechanism = "dnsbl"; - message = std::format( - "{} You are listed on some DNS blacklists. Get delisted before trying to send us email.", - code - ); - std::cout << std::format("checking {}\n", mechanism); - } - // Reverse DNS check - else if (cfg.getRejectNoReverseResolution() && !authenticated && resolvedname.empty()) { - code = cfg.getReturnTempErrorOnReject() ? "421" : "550"; - mechanism = "no reverse resolution"; - message = std::format( - "{} Your IP address does not resolve to a hostname.", - code - ); - std::cout << std::format("checking {}\n", mechanism); - } - // HELO/Reverse name check - else if (cfg.getCheckHeloAgainstReverse() && !authenticated && ehlostr != resolvedname) { - code = cfg.getReturnTempErrorOnReject() ? "421" : "550"; - mechanism = "helo differs from resolved name"; - message = std::format( - "{} Your IP hostname doesn't match your envelope hostname.", - code - ); - std::cout << std::format("checking {}\n", mechanism); - } - - // Prepare log message - if (!mechanism.empty()) { - strlog = std::format("({}) {}", mechanism, strlog); - } - strlog = std::format("{} {}", code, strlog); - std::cout << strlog << "\n"; - - // Handle rejection - if (code != "250") { - // Close inside connection - inside.lowest_layer().close(); - - // Delay to annoy spammers - std::this_thread::sleep_for(std::chrono::seconds(20)); - - // Send rejection message - std::string rejection_msg = message + "\r\n"; - boost::asio::write(outside->lowest_layer(), boost::asio::buffer(rejection_msg), ec); - return; - } - - last_state = SMTP_STATE_WAIT_FOR_DATA; - } - - // Send to inside server - boost::asio::write(inside.lowest_layer(), boost::asio::buffer(strtemp), ec); } // Check if the server wants to send something to the client + boost::system::error_code ec; size_t server_available = inside.lowest_layer().available(ec); - if (server_available > 0) { - size_t bytes_read = inside.read_some(boost::asio::buffer(read_buffer), ec); - strtemp = std::string(read_buffer.begin(), read_buffer.begin() + bytes_read); - - // Send to outside socket - boost::asio::write(outside->lowest_layer(), boost::asio::buffer(strtemp), ec); + if (!ec && server_available > 0) { + std::vector server_buffer(server_available); + size_t bytes_read = boost::asio::read(inside, boost::asio::buffer(server_buffer), ec); + if (!ec && bytes_read > 0) { + outside_socket_->writeBytes(server_buffer.data(), bytes_read); + } } + // Run io_context to process async operations + io_context_.poll(); + // Throttling if (throttled) { std::this_thread::sleep_for(std::chrono::seconds(cfg.getThrottlingTime())); @@ -249,9 +267,9 @@ void Proxy::run(boost::asio::ssl::stream* outside, } } catch (const boost::system::system_error& e) { - std::cerr << std::format("Boost.Asio error: {}", e.what()) << std::endl; + std::cerr << fmt::format("Boost.Asio error: {}\n", e.what()); } catch (const std::exception& e) { - std::cerr << std::format("Standard exception: {}", e.what()) << std::endl; + std::cerr << fmt::format("Standard exception: {}\n", e.what()); } } diff --git a/src/Utils.cpp b/src/Utils.cpp index b20d5bb..c553416 100644 --- a/src/Utils.cpp +++ b/src/Utils.cpp @@ -21,6 +21,11 @@ #include #include +#include + +using std::string; +using std::stringstream; +using std::list; extern Configfile cfg; extern LOGGER_CLASS hermes_log; @@ -69,9 +74,8 @@ string Utils::ulongtostr(unsigned long number) */ string Utils::strtolower(const std::string_view s) { - const std::string lower_str = boost::algorithm::to_lower_copy(s); - - return lower_str; + const std::string str(s); + return boost::algorithm::to_lower_copy(str); } /** @@ -84,13 +88,13 @@ string Utils::strtolower(const std::string_view s) */ string Utils::trim(const std::string_view s) { - while(isspace(s[0])) - s.erase(0,1); - - while(isspace(s[s.length()-1])) - s.erase(s.length()-1,1); - - return s; + auto start = s.find_first_not_of(" \t\n\r\f\v"); + if (start == std::string_view::npos) { + return string(); + } + + auto end = s.find_last_not_of(" \t\n\r\f\v"); + return string(s.substr(start, end - start + 1)); } //------------------------ @@ -141,7 +145,7 @@ string Utils::trim(const std::string_view s) * @return whether triplet should get greylisted or not * @todo unify {white,black,grey}list in one function that returns a different constant in each case */ -bool Utils::greylist(string dbfile,string& ip,string& p_from,string& p_to) +bool Utils::greylist(const string& dbfile, string& ip, string& p_from, string& p_to) { string from=Database::cleanString(p_from); string to=Database::cleanString(p_to); @@ -180,7 +184,7 @@ bool Utils::greylist(string dbfile,string& ip,string& p_from,string& p_to) * @return whether ip is whitelisted or not * @todo unify {white,black,grey}list in one function that returns a different constant in each case */ -bool Utils::whitelisted(string dbfile,string& ip) +bool Utils::whitelisted(const string& dbfile, string& ip) { Database db; string hostname; @@ -212,7 +216,7 @@ bool Utils::whitelisted(string dbfile,string& ip) * @return whether ip is whitelisted or not * @todo this should contain all cases when we should reject a connection */ -bool Utils::blacklisted(string dbfile,string& ip,string& to) +bool Utils::blacklisted(const string& dbfile, string& ip, string& to) { Database db; string hostname; @@ -233,7 +237,7 @@ bool Utils::blacklisted(string dbfile,string& ip,string& to) * @return email extracted from rawline * */ -string Utils::getmail(string& rawline) +string Utils::getmail(const string& rawline) { string email; string::size_type start=0,end=0; @@ -279,7 +283,7 @@ string Utils::getmail(string& rawline) * @return domain of email * */ -string Utils::getdomain(string& email) +string Utils::getdomain(const string& email) { if(email.rfind('@')) return trim(email.substr(email.rfind('@')+1)); @@ -306,7 +310,7 @@ string Utils::getdomain(string& email) * @return uid for user * */ -int Utils::usertouid(string user) +int Utils::usertouid(const string& user) { struct passwd *pwd; pwd=getpwnam(user.c_str()); @@ -329,7 +333,7 @@ int Utils::usertouid(string user) * @return gid for groupname * */ -int Utils::grouptogid(string groupname) +int Utils::grouptogid(const string& groupname) { struct group *grp; grp=getgrnam(groupname.c_str()); @@ -352,7 +356,7 @@ int Utils::grouptogid(string groupname) * @return is file readable? * */ -bool Utils::file_exists(string file) +bool Utils::file_exists(const string& file) { FILE *f=fopen(file.c_str(),"r"); if(NULL==f) @@ -365,7 +369,7 @@ bool Utils::file_exists(string file) } #ifdef WIN32 -string Utils::get_canonical_filename(string file) +string Utils::get_canonical_filename(const string& file) { char buffer[MAX_PATH]; @@ -374,7 +378,7 @@ string Utils::get_canonical_filename(string file) return string(buffer); } #else -string Utils::get_canonical_filename(string file) +string Utils::get_canonical_filename(const string& file) { char *buffer=NULL; string result; @@ -386,6 +390,7 @@ string Utils::get_canonical_filename(string file) return result; } #endif //WIN32 + /** * whether a directory is accesible by current process/user * @@ -394,7 +399,7 @@ string Utils::get_canonical_filename(string file) * @return isdir readable? * */ -bool Utils::dir_exists(string dir) +bool Utils::dir_exists(const string& dir) { DIR *d=opendir(dir.c_str()); if(NULL==d) @@ -431,8 +436,6 @@ string Utils::errnotostrerror(int errnum) strerr="Error "; #endif //WIN32 return string(strerr)+" ("+inttostr(errnum)+")"; -// else -// return string("Error "+inttostr(errno)+" retrieving error code for error number ")+inttostr(errnum); } /** @@ -458,7 +461,7 @@ string Utils::errnotostrerror(int errnum) * * @return whether ip is blacklisted or not */ -bool Utils::listed_on_dns_lists(list& dns_domains,unsigned char percentage,string& ip) +bool Utils::listed_on_dns_lists(const list& dns_domains, unsigned char percentage, const string& ip) { string reversedip; unsigned char number_of_lists=dns_domains.size(); @@ -467,7 +470,7 @@ bool Utils::listed_on_dns_lists(list& dns_domains,unsigned char percenta reversedip=reverseip(ip); - for(list::iterator i=dns_domains.begin();i!=dns_domains.end();i++) + for(list::const_iterator i=dns_domains.begin();i!=dns_domains.end();i++) { string dns_domain; @@ -514,7 +517,7 @@ bool Utils::listed_on_dns_lists(list& dns_domains,unsigned char percenta * * @return the reversed ip */ -string Utils::reverseip(string& ip) +string Utils::reverseip(const string& ip) { string inverseip=""; string::size_type pos=0,ppos=0; @@ -628,13 +631,19 @@ string Utils::gethostname() return string(buf); } -string Utils::gethostname(int s) +/** + * Get the hostname for a given socket + * + * @param socket Socket to get hostname for + * @return The hostname for the socket + */ +string Utils::gethostname(int socket) { struct sockaddr_in sa; unsigned int dummy = sizeof sa; struct hostent *hp; - if (getsockname(s,(struct sockaddr *) &sa,&dummy) == -1) + if (getsockname(socket,(struct sockaddr *) &sa,&dummy) == -1) throw Exception("Error getting ip from socket"+Utils::errnotostrerror(errno),__FILE__,__LINE__); hp=gethostbyaddr((const void *)&sa.sin_addr,sizeof sa.sin_addr,AF_INET); @@ -645,8 +654,13 @@ string Utils::gethostname(int s) return string(hp->h_name); } - -void Utils::write_pid(string file,pid_t pid) +/** + * Write process ID to a file + * + * @param file File to write PID to + * @param pid Process ID to write + */ +void Utils::write_pid(const string& file, pid_t pid) { FILE *f; @@ -657,4 +671,3 @@ void Utils::write_pid(string file,pid_t pid) fprintf(f,"%d\n",pid); fclose(f); } - diff --git a/src/win32-service.cpp b/src/win32-service.cpp index 43d71eb..358bae4 100644 --- a/src/win32-service.cpp +++ b/src/win32-service.cpp @@ -18,247 +18,221 @@ * @author Juan José Gutiérrez de Quevedo */ #include +#include +#include #include -using namespace std; +namespace { + // Service configuration constants + constexpr const char* SERVICE_NAME = "hermes anti-spam proxy"; + constexpr const char* SERVICE_SHORT_NAME = "hermes"; + constexpr const char* SERVICE_DESCRIPTION_TEXT = + "An anti-spam proxy using a combination of techniques like greylisting, dnsbl/dnswl, SPF, etc."; -#define SERVICE_NAME "hermes anti-spam proxy" -#define SERVICE_SHORT_NAME "hermes" -#define SERVICE_DESCRIPTION_TEXT "An anti-spam proxy using a combination of techniques like greylisting, dnsbl/dnswl, SPF, etc." + // Global state + SERVICE_STATUS service_status; + SERVICE_STATUS_HANDLE service_status_handle; + extern bool quit; -//macros -#define ChangeServiceStatus(x,y,z) y.dwCurrentState=z; SetServiceStatus(x,&y); -#define MIN(x,y) (((x)<(y))?(x):(y)) -#define cmp(x,y) strncmp(x,y,strlen(y)) -#define msgbox(x,y,z) MessageBox(NULL,x,z SERVICE_NAME,MB_OK|y) -#define winerror(x) msgbox(x,MB_ICONERROR,"Error from ") -#define winmessage(x) msgbox(x,MB_ICONINFORMATION,"Message from ") -#define condfree(x) if(NULL!=x) free(x); -#define safemalloc(x,y,z) do { if(NULL==(x=(z)malloc(y))) { winerror("Error reserving memory"); exit(-1); } memset(x,0,y); } while(0) -#define _(x) x + // Function declarations + static void WINAPI service_main(DWORD argc, LPTSTR* argv); + static void WINAPI handler(DWORD code); + static int service_install(); + static int service_uninstall(); + extern int hermes_main(int argc, char** argv); -/** - * The docs on microsoft's web don't seem very clear, so I have - * looked at the stunnel source code to understand how this thing - * works. What you see here is still original source, but is - * "inspired" by stunnel's source code (gui.c mainly). - * It's the real minimum needed to install, start and stop services - */ + // Helper class for managing command line parameters + class Parameters { + public: + Parameters() { + params_ = std::make_unique(2); + params_[0] = new char[1]; + params_[0][0] = '\0'; + params_[1] = new char[1024](); + } -extern bool quit; + ~Parameters() { + delete[] params_[0]; + delete[] params_[1]; + } -extern int hermes_main(int,char**); + char** get() { return params_.get(); } + char* operator[](size_t index) { return params_[index]; } -SERVICE_STATUS service_status; -SERVICE_STATUS_HANDLE service_status_handle; - -int WINAPI WinMain(HINSTANCE,HINSTANCE,LPSTR,int); -static void WINAPI service_main(DWORD,LPTSTR*); -static void WINAPI handler(DWORD); -static int service_install(); -static int service_uninstall(); - -char **params=NULL; -#define free_params() \ - do \ - { \ - if(NULL!=params) \ - { \ - condfree(params[0]); \ - condfree(params[1]); \ - } \ - condfree(params); \ - } \ - while(0) - -#define init_params() \ - do \ - { \ - free_params(); \ - safemalloc(params,sizeof(char *)*2,char **); \ - safemalloc(params[0],1*sizeof(char),char *); \ - params[0][0]='\0'; \ - safemalloc(params[1],1024*sizeof(char),char *); \ - } \ - while(0) - -int WINAPI WinMain(HINSTANCE instance,HINSTANCE previous_instance,LPSTR cmdline,int cmdshow) -{ - if(!cmp(cmdline,"-service")) - { - SERVICE_TABLE_ENTRY service_table[]={ - {SERVICE_SHORT_NAME,service_main}, - {NULL,NULL} + private: + std::unique_ptr params_; }; - if(0==StartServiceCtrlDispatcher(service_table)) - { - winerror("Error starting service dispatcher."); - return -1; + // Helper functions + void show_error(const char* message) { + MessageBox(NULL, message, ("Error from " SERVICE_NAME), MB_OK | MB_ICONERROR); } - } - else if(!cmp(cmdline,"-install")) - service_install(); - else if(!cmp(cmdline,"-uninstall")) - service_uninstall(); - else - { - //we know that hermes can only have one parameter, so - //just copy it - init_params(); - strncpy(params[1],cmdline,1024); - hermes_main(2,(char **)params); - free_params(); - } - return 0; + void show_message(const char* message) { + MessageBox(NULL, message, ("Message from " SERVICE_NAME), MB_OK | MB_ICONINFORMATION); + } + + void update_service_status(DWORD new_state) { + service_status.dwCurrentState = new_state; + SetServiceStatus(service_status_handle, &service_status); + } } -static int service_install() -{ - SC_HANDLE scm,service_handle; - SERVICE_DESCRIPTION service_description; - char filename[1024]; - string exepath; +int WINAPI WinMain(HINSTANCE instance, HINSTANCE previous_instance, LPSTR cmdline, int cmdshow) { + try { + if (strncmp(cmdline, "-service", strlen("-service")) == 0) { + SERVICE_TABLE_ENTRY service_table[] = { + {const_cast(SERVICE_SHORT_NAME), service_main}, + {NULL, NULL} + }; - if(NULL==(scm=OpenSCManager(NULL,NULL,SC_MANAGER_CREATE_SERVICE))) - { - winerror(_("Error opening connection to the Service Manager.")); - exit(-1); - } - if(0==GetModuleFileName(NULL,filename,sizeof(filename))) - { - winerror(_("Error getting the file name of the process.")); - exit(-1); - } + if (!StartServiceCtrlDispatcher(service_table)) { + throw std::runtime_error("Error starting service dispatcher"); + } + } + else if (strncmp(cmdline, "-install", strlen("-install")) == 0) { + return service_install(); + } + else if (strncmp(cmdline, "-uninstall", strlen("-uninstall")) == 0) { + return service_uninstall(); + } + else { + Parameters params; + strncpy(params[1], cmdline, 1023); + params[1][1023] = '\0'; // Ensure null termination + return hermes_main(2, params.get()); + } + } + catch (const std::exception& e) { + show_error(e.what()); + return -1; + } - exepath=string("\"")+filename+"\" -service"; - - service_handle=CreateService( - scm, //scm handle - SERVICE_SHORT_NAME, //service name - SERVICE_NAME, //display name - SERVICE_ALL_ACCESS, //desired access - SERVICE_WIN32_OWN_PROCESS, //service type - SERVICE_AUTO_START, //start type - SERVICE_ERROR_NORMAL, //error control - exepath.c_str(), //executable path with arguments - NULL, //load group - NULL, //tag for group id - NULL, //dependencies - NULL, //user name - NULL); //password - - if(NULL==service_handle) - { - winerror("Error creating service. Already installed?"); - exit(-1); - } - else - winmessage("Service successfully installed."); - - //createservice doesn't have a field for description - //so we use ChangeServiceConfig2 - service_description.lpDescription=SERVICE_DESCRIPTION_TEXT; - ChangeServiceConfig2(service_handle,SERVICE_CONFIG_DESCRIPTION,(void *)&service_description); - - CloseServiceHandle(service_handle); - CloseServiceHandle(scm); - - return 0; + return 0; } -static int service_uninstall() -{ - SC_HANDLE scm,service_handle; - SERVICE_STATUS status; +static int service_install() { + char filename[1024] = {0}; + if (GetModuleFileName(NULL, filename, sizeof(filename)) == 0) { + throw std::runtime_error("Error getting the file name of the process"); + } - if(NULL==(scm=OpenSCManager(NULL,NULL,SC_MANAGER_CREATE_SERVICE))) - { - winerror(_("Error opening connection to the Service Manager.")); - exit(-1); - } + SC_HANDLE scm = OpenSCManager(NULL, NULL, SC_MANAGER_CREATE_SERVICE); + if (!scm) { + throw std::runtime_error("Error opening connection to the Service Manager"); + } - if(NULL==(service_handle=OpenService(scm,SERVICE_SHORT_NAME,SERVICE_QUERY_STATUS|DELETE))) - { - winerror(_("Error opening service.")); - CloseServiceHandle(scm); - exit(-1); - } + std::string exepath = "\"" + std::string(filename) + "\" -service"; + + SC_HANDLE service_handle = CreateService( + scm, // SCM handle + SERVICE_SHORT_NAME, // Service name + SERVICE_NAME, // Display name + SERVICE_ALL_ACCESS, // Desired access + SERVICE_WIN32_OWN_PROCESS, // Service type + SERVICE_AUTO_START, // Start type + SERVICE_ERROR_NORMAL, // Error control + exepath.c_str(), // Executable path + NULL, NULL, NULL, NULL, NULL // Other parameters + ); + + if (!service_handle) { + CloseServiceHandle(scm); + throw std::runtime_error("Error creating service. Already installed?"); + } + + // Set service description + SERVICE_DESCRIPTION service_description = {const_cast(SERVICE_DESCRIPTION_TEXT)}; + ChangeServiceConfig2(service_handle, SERVICE_CONFIG_DESCRIPTION, &service_description); - if(0==QueryServiceStatus(service_handle,&status)) - { - winerror(_("Error querying service.")); - CloseServiceHandle(scm); CloseServiceHandle(service_handle); - exit(-1); - } - - if(SERVICE_STOPPED!=status.dwCurrentState) - { - winerror(SERVICE_NAME _(" is still running. Stop it before trying to uninstall it.")); CloseServiceHandle(scm); - CloseServiceHandle(service_handle); - exit(-1); - } + show_message("Service successfully installed"); + return 0; +} - if(0==DeleteService(service_handle)) - { - winerror(_("Error deleting service.")); +static int service_uninstall() { + SC_HANDLE scm = OpenSCManager(NULL, NULL, SC_MANAGER_CREATE_SERVICE); + if (!scm) { + throw std::runtime_error("Error opening connection to the Service Manager"); + } + + SC_HANDLE service_handle = OpenService(scm, SERVICE_SHORT_NAME, SERVICE_QUERY_STATUS | DELETE); + if (!service_handle) { + CloseServiceHandle(scm); + throw std::runtime_error("Error opening service"); + } + + SERVICE_STATUS status; + if (!QueryServiceStatus(service_handle, &status)) { + CloseServiceHandle(service_handle); + CloseServiceHandle(scm); + throw std::runtime_error("Error querying service"); + } + + if (status.dwCurrentState != SERVICE_STOPPED) { + CloseServiceHandle(service_handle); + CloseServiceHandle(scm); + throw std::runtime_error(SERVICE_NAME " is still running. Stop it before trying to uninstall it"); + } + + if (!DeleteService(service_handle)) { + CloseServiceHandle(service_handle); + CloseServiceHandle(scm); + throw std::runtime_error("Error deleting service"); + } + + CloseServiceHandle(service_handle); CloseServiceHandle(scm); - CloseServiceHandle(service_handle); - exit(-1); - } - - CloseServiceHandle(scm); - CloseServiceHandle(service_handle); - winmessage(_("Service successfully uninstalled.")); - return 0; + show_message("Service successfully uninstalled"); + return 0; } -static void WINAPI service_main(DWORD argc,LPTSTR *argv) -{ - char *tmpstr; +static void WINAPI service_main(DWORD argc, LPTSTR* argv) { + service_status = { + .dwServiceType = SERVICE_WIN32, + .dwControlsAccepted = SERVICE_ACCEPT_STOP, + .dwWin32ExitCode = NO_ERROR, + .dwServiceSpecificExitCode = NO_ERROR, + .dwCheckPoint = 0, + .dwWaitHint = 0 + }; - //configure service_status structure - service_status.dwServiceType=SERVICE_WIN32; - service_status.dwControlsAccepted=0; - service_status.dwWin32ExitCode=NO_ERROR; - service_status.dwServiceSpecificExitCode=NO_ERROR; - service_status.dwCheckPoint=0; - service_status.dwWaitHint=0; - service_status.dwControlsAccepted|=SERVICE_ACCEPT_STOP; + service_status_handle = RegisterServiceCtrlHandler(SERVICE_SHORT_NAME, handler); + if (!service_status_handle) { + return; + } - service_status_handle=RegisterServiceCtrlHandler(SERVICE_SHORT_NAME,handler); + update_service_status(SERVICE_RUNNING); - if(0!=service_status_handle) - { - //set service status - ChangeServiceStatus(service_status_handle,service_status,SERVICE_RUNNING); + try { + Parameters params; + if (GetModuleFileName(NULL, params[1], 1024) == 0) { + throw std::runtime_error("Error getting module filename"); + } - //get the path to the config file - init_params(); - GetModuleFileName(NULL,params[1],1024); - if(NULL==(tmpstr=strrchr(params[1],'\\'))) { winerror("Error finding default config file."); exit(-1); } - *(++tmpstr)='\0'; - strncat(params[1],"hermes.ini",strlen("hermes.ini")); + char* config_path = strrchr(params[1], '\\'); + if (!config_path) { + throw std::runtime_error("Error finding default config file"); + } - //now start our main program - hermes_main(2,(char **)params); + *(++config_path) = '\0'; + strncat(params[1], "hermes.ini", strlen("hermes.ini")); - free_params(); + hermes_main(2, params.get()); - //when we are here, we have been stopped - ChangeServiceStatus(service_status_handle,service_status,SERVICE_STOP_PENDING); - ChangeServiceStatus(service_status_handle,service_status,SERVICE_STOPPED); - } + update_service_status(SERVICE_STOP_PENDING); + update_service_status(SERVICE_STOPPED); + } + catch (const std::exception& e) { + show_error(e.what()); + update_service_status(SERVICE_STOPPED); + } } -static void WINAPI handler(DWORD code) -{ - if(SERVICE_CONTROL_STOP==code) - { - quit=true; - ChangeServiceStatus(service_status_handle,service_status,SERVICE_STOP_PENDING); - } +static void WINAPI handler(DWORD code) { + if (code == SERVICE_CONTROL_STOP) { + quit = true; + update_service_status(SERVICE_STOP_PENDING); + } }