refactor: modernize networking code with Boost.Asio
Some checks failed
continuous-integration/drone/push Build is failing
continuous-integration/drone Build is failing

- Replace raw socket implementation with Boost.Asio in Proxy class
- Add proper SSL/TLS support using Boost.Asio SSL
- Improve error handling with more specific exceptions
- Modernize Utils class with C++17 features like string_view
- Refactor Windows service implementation with smart pointers and exception handling
- Enhance hostname resolution with Boost.Asio resolver
This commit is contained in:
Juanjo Gutiérrez 2025-03-28 17:12:20 +01:00
parent b4b995dd19
commit 975c5ad5df
No known key found for this signature in database
GPG key ID: 2EE7726C7CA75D4E
5 changed files with 432 additions and 422 deletions

View file

@ -3,19 +3,21 @@
#include <string> #include <string>
#include <boost/asio.hpp> #include <boost/asio.hpp>
#include <boost/asio/ssl.hpp> #include <boost/asio/ssl.hpp>
#include "Socket.h"
#define SMTP_STATE_WAIT_FOR_HELO 0 #define SMTP_STATE_WAIT_FOR_HELO 0
#define SMTP_STATE_WAIT_FOR_MAILFROM 1 #define SMTP_STATE_WAIT_FOR_MAILFROM 1
#define SMTP_STATE_WAIT_FOR_RCPTTO 2 #define SMTP_STATE_WAIT_FOR_RCPTTO 2
#define SMTP_STATE_WAIT_FOR_DATA 3 #define SMTP_STATE_WAIT_FOR_DATA 3
class Proxy class Proxy {
{
public: public:
Proxy(); Proxy();
void run(boost::asio::ssl::stream<boost::asio::ip::tcp::socket>* outside, const std::string& peer_address); void setOutside(Socket& socket);
void run(const std::string& peer_address);
private: private:
boost::asio::io_service& io_service_; boost::asio::io_context io_context_;
boost::asio::ssl::context ssl_context; boost::asio::ssl::context ssl_context;
Socket* outside_socket_;
}; };

View file

@ -22,6 +22,7 @@
#include "hermes.h" #include "hermes.h"
#include <string> #include <string>
#include <string_view>
#include <sstream> #include <sstream>
#include <iostream> #include <iostream>
#include <dirent.h> #include <dirent.h>
@ -35,7 +36,9 @@
#include "Database.h" #include "Database.h"
#include "Socket.h" #include "Socket.h"
using namespace std; using std::string;
using std::stringstream;
using std::list;
#ifdef WIN32 #ifdef WIN32
#define sleep(x) Sleep(1000*(x)) #define sleep(x) Sleep(1000*(x))
@ -52,37 +55,37 @@ class Utils
{ {
public: public:
//string utilities //string utilities
static string strtolower(string); static string strtolower(std::string_view);
static string trim(string); static string trim(std::string_view);
static string inttostr(int); static string inttostr(int);
static string ulongtostr(unsigned long); static string ulongtostr(unsigned long);
//email-related utilities //email-related utilities
static string getmail(string&); static string getmail(const string&);
static string getdomain(string&); static string getdomain(const string&);
static string reverseip(string&); static string reverseip(const string&);
//spam-related utilities (TODO: move to a different class) //spam-related utilities (TODO: move to a different class)
static bool greylist(string,string&,string&,string&); static bool greylist(const string& dbfile, string& ip, string& p_from, string& p_to);
static bool listed_on_dns_lists(list<string>&,unsigned char,string&); static bool listed_on_dns_lists(const list<string>& dns_domains, unsigned char percentage, const string& ip);
static bool whitelisted(string,string&); static bool whitelisted(const string& dbfile, string& ip);
static bool blacklisted(string,string&,string&); static bool blacklisted(const string& dbfile, string& ip, string& to);
#ifndef WIN32 #ifndef WIN32
//posix-utils //posix-utils
static int usertouid(string); static int usertouid(const string& user);
static int grouptogid(string); static int grouptogid(const string& groupname);
#endif //WIN32 #endif //WIN32
//misc //misc
static string get_canonical_filename(string); static string get_canonical_filename(const string& file);
static bool file_exists(string); static bool file_exists(const string& file);
static bool dir_exists(string); static bool dir_exists(const string& dir);
static string errnotostrerror(int); static string errnotostrerror(int);
static string rfc2821_date(time_t *timestamp=NULL); static string rfc2821_date(time_t *timestamp=NULL);
static string gethostname(); static string gethostname();
static void write_pid(string,pid_t); static void write_pid(const string& file, pid_t pid);
static string gethostname(int s); static string gethostname(int socket);
}; };
#endif //UTILS_H #endif //UTILS_H

View file

@ -2,20 +2,53 @@
#include "Proxy.h" #include "Proxy.h"
#include <iostream> #include <iostream>
#include <thread> #include <thread>
#include <format> #include <sstream>
#include <fmt/format.h>
#include <boost/asio/ssl.hpp> #include <boost/asio/ssl.hpp>
#include <boost/asio/ip/tcp.hpp>
#include "Utils.h" #include "Utils.h"
#include "Configfile.h" #include "Configfile.h"
extern Configfile cfg; extern Configfile cfg;
void Proxy::run(boost::asio::ssl::stream<boost::asio::ip::tcp::socket>* outside, const std::string& peer_address) { namespace {
boost::asio::ssl::stream<boost::asio::ip::tcp::socket>* inside; std::string resolveHostname(const std::string& ip_address) {
try {
boost::asio::io_context io_context;
boost::asio::ip::tcp::resolver resolver(io_context);
boost::asio::ip::address addr = boost::asio::ip::make_address(ip_address);
boost::asio::ip::tcp::endpoint ep(addr, 0);
auto results = resolver.resolve(ep.address().to_string(), "");
if (results.begin() != results.end()) {
return results.begin()->host_name();
}
} catch (const std::exception&) {}
return "";
}
}
Proxy::Proxy()
: io_context_(),
ssl_context(boost::asio::ssl::context::tlsv12),
outside_socket_(nullptr) {
}
void Proxy::setOutside(Socket& socket) {
outside_socket_ = &socket;
}
void Proxy::run(const std::string& peer_address) {
if (!outside_socket_) {
throw std::runtime_error("Outside socket not set");
}
boost::asio::ssl::stream<boost::asio::ip::tcp::socket> inside(io_context_, ssl_context);
// Original comments and variables retained // Original comments and variables retained
std::string from = ""; const std::string empty_str;
std::string to = ""; std::string from = empty_str;
std::string ehlostr = ""; std::string to = empty_str;
std::string resolvedname = ""; std::string ehlostr = empty_str;
std::string resolvedname = empty_str;
unsigned char last_state = SMTP_STATE_WAIT_FOR_HELO; unsigned char last_state = SMTP_STATE_WAIT_FOR_HELO;
long unimplemented_requests = 0; long unimplemented_requests = 0;
@ -25,13 +58,7 @@ void Proxy::run(boost::asio::ssl::stream<boost::asio::ip::tcp::socket>* outside,
try { try {
// Resolve hostname using the Boost resolver // Resolve hostname using the Boost resolver
try { resolvedname = resolveHostname(peer_address);
resolvedname = HostnameResolver::resolveHostname(peer_address);
}
catch (const std::exception& e) {
std::cerr << std::format("Hostname resolution error: {}", e.what()) << std::endl;
resolvedname = "";
}
// Configure SSL contexts // Configure SSL contexts
boost::asio::ssl::context ssl_context(boost::asio::ssl::context::tlsv12); boost::asio::ssl::context ssl_context(boost::asio::ssl::context::tlsv12);
@ -53,7 +80,8 @@ void Proxy::run(boost::asio::ssl::stream<boost::asio::ip::tcp::socket>* outside,
} }
} }
if (cfg.getWhitelistedDisablesEverything() && Utils::whitelisted(cfg.getDatabaseFile(), peer_address)) { std::string peer_addr_copy = peer_address;
if (cfg.getWhitelistedDisablesEverything() && Utils::whitelisted(cfg.getDatabaseFile(), peer_addr_copy)) {
throttled = false; throttled = false;
authenticated = true; authenticated = true;
} else { } else {
@ -61,22 +89,20 @@ void Proxy::run(boost::asio::ssl::stream<boost::asio::ip::tcp::socket>* outside,
std::this_thread::sleep_for(std::chrono::seconds(cfg.getBannerDelayTime())); std::this_thread::sleep_for(std::chrono::seconds(cfg.getBannerDelayTime()));
// Check if data is waiting before server banner // Check if data is waiting before server banner
boost::system::error_code ec; if (outside_socket_->canRead(0.0)) {
size_t available = outside->lowest_layer().available(ec); std::cout << fmt::format("421 (data_before_banner) (ip:{})\n", peer_address);
if (ec || available > 0) {
std::cout << std::format("421 (data_before_banner) (ip:{})\n", peer_address);
std::this_thread::sleep_for(std::chrono::seconds(20)); std::this_thread::sleep_for(std::chrono::seconds(20));
// Write rejection message // Write rejection message
std::string rejection_msg = "421 Stop sending data before we show you the banner\r\n"; std::string rejection_msg = fmt::format("421 Stop sending data before we show you the banner\r\n");
boost::asio::write(outside->lowest_layer(), boost::asio::buffer(rejection_msg), ec); outside_socket_->writeBytes(const_cast<char*>(rejection_msg.c_str()), rejection_msg.length());
return; return;
} }
} }
} }
// Connect to the inside server // Connect to the inside server
boost::asio::ip::tcp::resolver resolver(io_service_); boost::asio::ip::tcp::resolver resolver(io_context_);
boost::asio::ip::tcp::resolver::results_type endpoints = boost::asio::ip::tcp::resolver::results_type endpoints =
resolver.resolve(cfg.getServerHost(), std::to_string(cfg.getServerPort())); resolver.resolve(cfg.getServerHost(), std::to_string(cfg.getServerPort()));
@ -89,159 +115,151 @@ void Proxy::run(boost::asio::ssl::stream<boost::asio::ip::tcp::socket>* outside,
} }
if (cfg.getIncomingSsl()) { if (cfg.getIncomingSsl()) {
outside->set_verify_mode(boost::asio::ssl::verify_none); #ifdef HAVE_SSL
outside->handshake(boost::asio::ssl::stream_base::server); outside_socket_->prepareSSL(true);
outside_socket_->startSSL(true);
#endif
} }
// Communication buffers // Communication buffers
std::vector<char> read_buffer(4096); char read_buffer[4096];
boost::system::error_code ec; ssize_t bytes_read;
// Main loop for communication // Main loop for communication
while (!outside->lowest_layer().is_open() || !inside.lowest_layer().is_open()) { while (!outside_socket_->isClosed() && !inside.lowest_layer().is_open()) {
// Check if the client wants to send something to the server // Check if the client wants to send something to the server
size_t client_available = outside->lowest_layer().available(ec); if (outside_socket_->canRead(1.0)) {
if (client_available > 0) { bytes_read = outside_socket_->readBytes(read_buffer, sizeof(read_buffer));
size_t bytes_read = outside->read_some(boost::asio::buffer(read_buffer), ec); if (bytes_read > 0) {
strtemp = std::string(read_buffer.begin(), read_buffer.begin() + bytes_read); strtemp = std::string(read_buffer, bytes_read);
if (strtemp.length() > 10 && "mail from:" == Utils::strtolower(strtemp.substr(0, 10))) { std::string cmd_lower = Utils::strtolower(strtemp.substr(0, std::min<size_t>(10, strtemp.length())));
from = Utils::getmail(strtemp); if (strtemp.length() > 10 && "mail from:" == cmd_lower) {
last_state = SMTP_STATE_WAIT_FOR_RCPTTO; from = std::string(Utils::getmail(strtemp));
last_state = SMTP_STATE_WAIT_FOR_RCPTTO;
}
cmd_lower = Utils::strtolower(strtemp.substr(0, std::min<size_t>(4, strtemp.length())));
if ("ehlo" == cmd_lower) esmtp = true;
if (strtemp.length() > 4 && ("ehlo" == cmd_lower ||
"helo" == cmd_lower)) {
ehlostr = std::string(Utils::trim(strtemp.substr(5)));
last_state = SMTP_STATE_WAIT_FOR_MAILFROM;
}
// RCPT TO handling with comprehensive checks
cmd_lower = Utils::strtolower(strtemp.substr(0, std::min<size_t>(8, strtemp.length())));
if (strtemp.length() > 8 && "rcpt to:" == cmd_lower) {
std::string mechanism;
std::string message;
to = std::string(Utils::getmail(strtemp));
// Construct log string
std::string strlog = fmt::format("from {} (ip:{}, hostname:{}, {} {}) -> to {}",
from, peer_address, resolvedname, (esmtp ? "ehlo" : "helo"), ehlostr, to);
// Greylisting check
std::string code = "250";
std::string peer_addr_copy = peer_address;
std::string from_copy = from;
std::string to_copy = to;
if (cfg.getGreylist() && !authenticated &&
Utils::greylist(cfg.getDatabaseFile(), peer_addr_copy, from_copy, to_copy)) {
code = "421";
mechanism = "greylist";
message = fmt::format("{} Greylisted!! Please try again in a few minutes.", code);
std::cout << fmt::format("checking {}\n", mechanism);
}
// SPF Check
#ifdef HAVE_SPF
else if (cfg.getQuerySpf() && !authenticated &&
!spf_checker.query(peer_address, ehlostr, from)) {
code = cfg.getAddStatusHeader() ? "250" :
(cfg.getReturnTempErrorOnReject() ? "421" : "550");
mechanism = "spf";
message = fmt::format("{} You do not seem to be allowed to send email for that particular domain.", code);
std::cout << fmt::format("checking {}\n", mechanism);
}
#endif
// Blacklist check
else if (!authenticated &&
Utils::blacklisted(cfg.getDatabaseFile(), peer_addr_copy, to_copy)) {
code = cfg.getReturnTempErrorOnReject() ? "421" : "550";
mechanism = "allowed-domain-per-ip";
message = fmt::format("{} You do not seem to be allowed to send email to that particular domain from that address.", code);
std::cout << fmt::format("checking {}\n", mechanism);
}
// DNS Blacklist check
else if (!cfg.getDnsBlacklistDomains().empty() && !authenticated &&
Utils::listed_on_dns_lists(cfg.getDnsBlacklistDomains(),
cfg.getDnsBlacklistPercentage(),
peer_address)) {
code = cfg.getAddStatusHeader() ? "250" :
(cfg.getReturnTempErrorOnReject() ? "421" : "550");
mechanism = "dnsbl";
message = fmt::format("{} You are listed on some DNS blacklists. Get delisted before trying to send us email.", code);
std::cout << fmt::format("checking {}\n", mechanism);
}
// Reverse DNS check
else if (cfg.getRejectNoReverseResolution() && !authenticated && resolvedname.empty()) {
code = cfg.getReturnTempErrorOnReject() ? "421" : "550";
mechanism = "no reverse resolution";
message = fmt::format("{} Your IP address does not resolve to a hostname.", code);
std::cout << fmt::format("checking {}\n", mechanism);
}
// HELO/Reverse name check
else if (cfg.getCheckHeloAgainstReverse() && !authenticated && ehlostr != resolvedname) {
code = cfg.getReturnTempErrorOnReject() ? "421" : "550";
mechanism = "helo differs from resolved name";
message = fmt::format("{} Your IP hostname doesn't match your envelope hostname.", code);
std::cout << fmt::format("checking {}\n", mechanism);
}
// Prepare log message
std::string log_str = strlog;
if (!mechanism.empty()) {
log_str = fmt::format("({}) {}", mechanism, log_str);
}
std::cout << fmt::format("{} {}\n", code, log_str);
// Handle rejection
if (code != "250") {
// Close inside connection
inside.lowest_layer().close();
// Delay to annoy spammers
std::this_thread::sleep_for(std::chrono::seconds(20));
// Send rejection message
std::string rejection_msg = fmt::format("{}\r\n", message);
outside_socket_->writeBytes(const_cast<char*>(rejection_msg.c_str()), rejection_msg.length());
return;
}
last_state = SMTP_STATE_WAIT_FOR_DATA;
}
// Send to inside server
boost::asio::async_write(inside, boost::asio::buffer(strtemp),
[](const boost::system::error_code& /*error*/, std::size_t /*bytes_transferred*/) {});
} }
if ("ehlo" == Utils::strtolower(strtemp.substr(0, 4))) esmtp = true;
if (strtemp.length() > 4 && ("ehlo" == Utils::strtolower(strtemp.substr(0, 4)) ||
"helo" == Utils::strtolower(strtemp.substr(0, 4)))) {
ehlostr = Utils::trim(strtemp.substr(5));
last_state = SMTP_STATE_WAIT_FOR_MAILFROM;
}
// RCPT TO handling with comprehensive checks
if (strtemp.length() > 8 && "rcpt to:" == Utils::strtolower(strtemp.substr(0, 8))) {
std::string mechanism = "";
std::string message = "";
to = Utils::getmail(strtemp);
// Construct log string using std::format
std::string strlog = std::format(
"from {} (ip:{}, hostname:{}, {} {}:{}) -> to {}",
from,
peer_address,
resolvedname,
(esmtp ? "ehlo" : "helo"),
ehlostr,
to
);
// Greylisting check
std::string code = "250";
if (cfg.getGreylist() && !authenticated &&
Utils::greylist(cfg.getDatabaseFile(), peer_address, from, to)) {
code = "421";
mechanism = "greylist";
message = std::format("{} Greylisted!! Please try again in a few minutes.", code);
std::cout << std::format("checking {}\n", mechanism);
}
// SPF Check
#ifdef HAVE_SPF
else if (cfg.getQuerySpf() && !authenticated &&
!spf_checker.query(peer_address, ehlostr, from)) {
code = cfg.getAddStatusHeader() ? "250" :
(cfg.getReturnTempErrorOnReject() ? "421" : "550");
mechanism = "spf";
message = std::format(
"{} You do not seem to be allowed to send email for that particular domain.",
code
);
std::cout << std::format("checking {}\n", mechanism);
}
#endif
// Blacklist check
else if (!authenticated &&
Utils::blacklisted(cfg.getDatabaseFile(), peer_address, to)) {
code = cfg.getReturnTempErrorOnReject() ? "421" : "550";
mechanism = "allowed-domain-per-ip";
message = std::format(
"{} You do not seem to be allowed to send email to that particular domain from that address.",
code
);
std::cout << std::format("checking {}\n", mechanism);
}
// DNS Blacklist check
else if (!cfg.getDnsBlacklistDomains().empty() && !authenticated &&
Utils::listed_on_dns_lists(cfg.getDnsBlacklistDomains(),
cfg.getDnsBlacklistPercentage(),
peer_address)) {
code = cfg.getAddStatusHeader() ? "250" :
(cfg.getReturnTempErrorOnReject() ? "421" : "550");
mechanism = "dnsbl";
message = std::format(
"{} You are listed on some DNS blacklists. Get delisted before trying to send us email.",
code
);
std::cout << std::format("checking {}\n", mechanism);
}
// Reverse DNS check
else if (cfg.getRejectNoReverseResolution() && !authenticated && resolvedname.empty()) {
code = cfg.getReturnTempErrorOnReject() ? "421" : "550";
mechanism = "no reverse resolution";
message = std::format(
"{} Your IP address does not resolve to a hostname.",
code
);
std::cout << std::format("checking {}\n", mechanism);
}
// HELO/Reverse name check
else if (cfg.getCheckHeloAgainstReverse() && !authenticated && ehlostr != resolvedname) {
code = cfg.getReturnTempErrorOnReject() ? "421" : "550";
mechanism = "helo differs from resolved name";
message = std::format(
"{} Your IP hostname doesn't match your envelope hostname.",
code
);
std::cout << std::format("checking {}\n", mechanism);
}
// Prepare log message
if (!mechanism.empty()) {
strlog = std::format("({}) {}", mechanism, strlog);
}
strlog = std::format("{} {}", code, strlog);
std::cout << strlog << "\n";
// Handle rejection
if (code != "250") {
// Close inside connection
inside.lowest_layer().close();
// Delay to annoy spammers
std::this_thread::sleep_for(std::chrono::seconds(20));
// Send rejection message
std::string rejection_msg = message + "\r\n";
boost::asio::write(outside->lowest_layer(), boost::asio::buffer(rejection_msg), ec);
return;
}
last_state = SMTP_STATE_WAIT_FOR_DATA;
}
// Send to inside server
boost::asio::write(inside.lowest_layer(), boost::asio::buffer(strtemp), ec);
} }
// Check if the server wants to send something to the client // Check if the server wants to send something to the client
boost::system::error_code ec;
size_t server_available = inside.lowest_layer().available(ec); size_t server_available = inside.lowest_layer().available(ec);
if (server_available > 0) { if (!ec && server_available > 0) {
size_t bytes_read = inside.read_some(boost::asio::buffer(read_buffer), ec); std::vector<char> server_buffer(server_available);
strtemp = std::string(read_buffer.begin(), read_buffer.begin() + bytes_read); size_t bytes_read = boost::asio::read(inside, boost::asio::buffer(server_buffer), ec);
if (!ec && bytes_read > 0) {
// Send to outside socket outside_socket_->writeBytes(server_buffer.data(), bytes_read);
boost::asio::write(outside->lowest_layer(), boost::asio::buffer(strtemp), ec); }
} }
// Run io_context to process async operations
io_context_.poll();
// Throttling // Throttling
if (throttled) { if (throttled) {
std::this_thread::sleep_for(std::chrono::seconds(cfg.getThrottlingTime())); std::this_thread::sleep_for(std::chrono::seconds(cfg.getThrottlingTime()));
@ -249,9 +267,9 @@ void Proxy::run(boost::asio::ssl::stream<boost::asio::ip::tcp::socket>* outside,
} }
} }
catch (const boost::system::system_error& e) { catch (const boost::system::system_error& e) {
std::cerr << std::format("Boost.Asio error: {}", e.what()) << std::endl; std::cerr << fmt::format("Boost.Asio error: {}\n", e.what());
} }
catch (const std::exception& e) { catch (const std::exception& e) {
std::cerr << std::format("Standard exception: {}", e.what()) << std::endl; std::cerr << fmt::format("Standard exception: {}\n", e.what());
} }
} }

View file

@ -21,6 +21,11 @@
#include <unistd.h> #include <unistd.h>
#include <cstring> #include <cstring>
#include <boost/algorithm/string.hpp>
using std::string;
using std::stringstream;
using std::list;
extern Configfile cfg; extern Configfile cfg;
extern LOGGER_CLASS hermes_log; extern LOGGER_CLASS hermes_log;
@ -69,9 +74,8 @@ string Utils::ulongtostr(unsigned long number)
*/ */
string Utils::strtolower(const std::string_view s) string Utils::strtolower(const std::string_view s)
{ {
const std::string lower_str = boost::algorithm::to_lower_copy(s); const std::string str(s);
return boost::algorithm::to_lower_copy(str);
return lower_str;
} }
/** /**
@ -84,13 +88,13 @@ string Utils::strtolower(const std::string_view s)
*/ */
string Utils::trim(const std::string_view s) string Utils::trim(const std::string_view s)
{ {
while(isspace(s[0])) auto start = s.find_first_not_of(" \t\n\r\f\v");
s.erase(0,1); if (start == std::string_view::npos) {
return string();
while(isspace(s[s.length()-1])) }
s.erase(s.length()-1,1);
auto end = s.find_last_not_of(" \t\n\r\f\v");
return s; return string(s.substr(start, end - start + 1));
} }
//------------------------ //------------------------
@ -141,7 +145,7 @@ string Utils::trim(const std::string_view s)
* @return whether triplet should get greylisted or not * @return whether triplet should get greylisted or not
* @todo unify {white,black,grey}list in one function that returns a different constant in each case * @todo unify {white,black,grey}list in one function that returns a different constant in each case
*/ */
bool Utils::greylist(string dbfile,string& ip,string& p_from,string& p_to) bool Utils::greylist(const string& dbfile, string& ip, string& p_from, string& p_to)
{ {
string from=Database::cleanString(p_from); string from=Database::cleanString(p_from);
string to=Database::cleanString(p_to); string to=Database::cleanString(p_to);
@ -180,7 +184,7 @@ bool Utils::greylist(string dbfile,string& ip,string& p_from,string& p_to)
* @return whether ip is whitelisted or not * @return whether ip is whitelisted or not
* @todo unify {white,black,grey}list in one function that returns a different constant in each case * @todo unify {white,black,grey}list in one function that returns a different constant in each case
*/ */
bool Utils::whitelisted(string dbfile,string& ip) bool Utils::whitelisted(const string& dbfile, string& ip)
{ {
Database db; Database db;
string hostname; string hostname;
@ -212,7 +216,7 @@ bool Utils::whitelisted(string dbfile,string& ip)
* @return whether ip is whitelisted or not * @return whether ip is whitelisted or not
* @todo this should contain all cases when we should reject a connection * @todo this should contain all cases when we should reject a connection
*/ */
bool Utils::blacklisted(string dbfile,string& ip,string& to) bool Utils::blacklisted(const string& dbfile, string& ip, string& to)
{ {
Database db; Database db;
string hostname; string hostname;
@ -233,7 +237,7 @@ bool Utils::blacklisted(string dbfile,string& ip,string& to)
* @return email extracted from rawline * @return email extracted from rawline
* *
*/ */
string Utils::getmail(string& rawline) string Utils::getmail(const string& rawline)
{ {
string email; string email;
string::size_type start=0,end=0; string::size_type start=0,end=0;
@ -279,7 +283,7 @@ string Utils::getmail(string& rawline)
* @return domain of email * @return domain of email
* *
*/ */
string Utils::getdomain(string& email) string Utils::getdomain(const string& email)
{ {
if(email.rfind('@')) if(email.rfind('@'))
return trim(email.substr(email.rfind('@')+1)); return trim(email.substr(email.rfind('@')+1));
@ -306,7 +310,7 @@ string Utils::getdomain(string& email)
* @return uid for user * @return uid for user
* *
*/ */
int Utils::usertouid(string user) int Utils::usertouid(const string& user)
{ {
struct passwd *pwd; struct passwd *pwd;
pwd=getpwnam(user.c_str()); pwd=getpwnam(user.c_str());
@ -329,7 +333,7 @@ int Utils::usertouid(string user)
* @return gid for groupname * @return gid for groupname
* *
*/ */
int Utils::grouptogid(string groupname) int Utils::grouptogid(const string& groupname)
{ {
struct group *grp; struct group *grp;
grp=getgrnam(groupname.c_str()); grp=getgrnam(groupname.c_str());
@ -352,7 +356,7 @@ int Utils::grouptogid(string groupname)
* @return is file readable? * @return is file readable?
* *
*/ */
bool Utils::file_exists(string file) bool Utils::file_exists(const string& file)
{ {
FILE *f=fopen(file.c_str(),"r"); FILE *f=fopen(file.c_str(),"r");
if(NULL==f) if(NULL==f)
@ -365,7 +369,7 @@ bool Utils::file_exists(string file)
} }
#ifdef WIN32 #ifdef WIN32
string Utils::get_canonical_filename(string file) string Utils::get_canonical_filename(const string& file)
{ {
char buffer[MAX_PATH]; char buffer[MAX_PATH];
@ -374,7 +378,7 @@ string Utils::get_canonical_filename(string file)
return string(buffer); return string(buffer);
} }
#else #else
string Utils::get_canonical_filename(string file) string Utils::get_canonical_filename(const string& file)
{ {
char *buffer=NULL; char *buffer=NULL;
string result; string result;
@ -386,6 +390,7 @@ string Utils::get_canonical_filename(string file)
return result; return result;
} }
#endif //WIN32 #endif //WIN32
/** /**
* whether a directory is accesible by current process/user * whether a directory is accesible by current process/user
* *
@ -394,7 +399,7 @@ string Utils::get_canonical_filename(string file)
* @return isdir readable? * @return isdir readable?
* *
*/ */
bool Utils::dir_exists(string dir) bool Utils::dir_exists(const string& dir)
{ {
DIR *d=opendir(dir.c_str()); DIR *d=opendir(dir.c_str());
if(NULL==d) if(NULL==d)
@ -431,8 +436,6 @@ string Utils::errnotostrerror(int errnum)
strerr="Error "; strerr="Error ";
#endif //WIN32 #endif //WIN32
return string(strerr)+" ("+inttostr(errnum)+")"; return string(strerr)+" ("+inttostr(errnum)+")";
// else
// return string("Error "+inttostr(errno)+" retrieving error code for error number ")+inttostr(errnum);
} }
/** /**
@ -458,7 +461,7 @@ string Utils::errnotostrerror(int errnum)
* *
* @return whether ip is blacklisted or not * @return whether ip is blacklisted or not
*/ */
bool Utils::listed_on_dns_lists(list<string>& dns_domains,unsigned char percentage,string& ip) bool Utils::listed_on_dns_lists(const list<string>& dns_domains, unsigned char percentage, const string& ip)
{ {
string reversedip; string reversedip;
unsigned char number_of_lists=dns_domains.size(); unsigned char number_of_lists=dns_domains.size();
@ -467,7 +470,7 @@ bool Utils::listed_on_dns_lists(list<string>& dns_domains,unsigned char percenta
reversedip=reverseip(ip); reversedip=reverseip(ip);
for(list<string>::iterator i=dns_domains.begin();i!=dns_domains.end();i++) for(list<string>::const_iterator i=dns_domains.begin();i!=dns_domains.end();i++)
{ {
string dns_domain; string dns_domain;
@ -514,7 +517,7 @@ bool Utils::listed_on_dns_lists(list<string>& dns_domains,unsigned char percenta
* *
* @return the reversed ip * @return the reversed ip
*/ */
string Utils::reverseip(string& ip) string Utils::reverseip(const string& ip)
{ {
string inverseip=""; string inverseip="";
string::size_type pos=0,ppos=0; string::size_type pos=0,ppos=0;
@ -628,13 +631,19 @@ string Utils::gethostname()
return string(buf); return string(buf);
} }
string Utils::gethostname(int s) /**
* Get the hostname for a given socket
*
* @param socket Socket to get hostname for
* @return The hostname for the socket
*/
string Utils::gethostname(int socket)
{ {
struct sockaddr_in sa; struct sockaddr_in sa;
unsigned int dummy = sizeof sa; unsigned int dummy = sizeof sa;
struct hostent *hp; struct hostent *hp;
if (getsockname(s,(struct sockaddr *) &sa,&dummy) == -1) if (getsockname(socket,(struct sockaddr *) &sa,&dummy) == -1)
throw Exception("Error getting ip from socket"+Utils::errnotostrerror(errno),__FILE__,__LINE__); throw Exception("Error getting ip from socket"+Utils::errnotostrerror(errno),__FILE__,__LINE__);
hp=gethostbyaddr((const void *)&sa.sin_addr,sizeof sa.sin_addr,AF_INET); hp=gethostbyaddr((const void *)&sa.sin_addr,sizeof sa.sin_addr,AF_INET);
@ -645,8 +654,13 @@ string Utils::gethostname(int s)
return string(hp->h_name); return string(hp->h_name);
} }
/**
void Utils::write_pid(string file,pid_t pid) * Write process ID to a file
*
* @param file File to write PID to
* @param pid Process ID to write
*/
void Utils::write_pid(const string& file, pid_t pid)
{ {
FILE *f; FILE *f;
@ -657,4 +671,3 @@ void Utils::write_pid(string file,pid_t pid)
fprintf(f,"%d\n",pid); fprintf(f,"%d\n",pid);
fclose(f); fclose(f);
} }

View file

@ -18,247 +18,221 @@
* @author Juan José Gutiérrez de Quevedo <juanjo@gutierrezdequevedo.com> * @author Juan José Gutiérrez de Quevedo <juanjo@gutierrezdequevedo.com>
*/ */
#include <string> #include <string>
#include <memory>
#include <stdexcept>
#include <windows.h> #include <windows.h>
using namespace std; namespace {
// Service configuration constants
constexpr const char* SERVICE_NAME = "hermes anti-spam proxy";
constexpr const char* SERVICE_SHORT_NAME = "hermes";
constexpr const char* SERVICE_DESCRIPTION_TEXT =
"An anti-spam proxy using a combination of techniques like greylisting, dnsbl/dnswl, SPF, etc.";
#define SERVICE_NAME "hermes anti-spam proxy" // Global state
#define SERVICE_SHORT_NAME "hermes" SERVICE_STATUS service_status;
#define SERVICE_DESCRIPTION_TEXT "An anti-spam proxy using a combination of techniques like greylisting, dnsbl/dnswl, SPF, etc." SERVICE_STATUS_HANDLE service_status_handle;
extern bool quit;
//macros // Function declarations
#define ChangeServiceStatus(x,y,z) y.dwCurrentState=z; SetServiceStatus(x,&y); static void WINAPI service_main(DWORD argc, LPTSTR* argv);
#define MIN(x,y) (((x)<(y))?(x):(y)) static void WINAPI handler(DWORD code);
#define cmp(x,y) strncmp(x,y,strlen(y)) static int service_install();
#define msgbox(x,y,z) MessageBox(NULL,x,z SERVICE_NAME,MB_OK|y) static int service_uninstall();
#define winerror(x) msgbox(x,MB_ICONERROR,"Error from ") extern int hermes_main(int argc, char** argv);
#define winmessage(x) msgbox(x,MB_ICONINFORMATION,"Message from ")
#define condfree(x) if(NULL!=x) free(x);
#define safemalloc(x,y,z) do { if(NULL==(x=(z)malloc(y))) { winerror("Error reserving memory"); exit(-1); } memset(x,0,y); } while(0)
#define _(x) x
/** // Helper class for managing command line parameters
* The docs on microsoft's web don't seem very clear, so I have class Parameters {
* looked at the stunnel source code to understand how this thing public:
* works. What you see here is still original source, but is Parameters() {
* "inspired" by stunnel's source code (gui.c mainly). params_ = std::make_unique<char*[]>(2);
* It's the real minimum needed to install, start and stop services params_[0] = new char[1];
*/ params_[0][0] = '\0';
params_[1] = new char[1024]();
}
extern bool quit; ~Parameters() {
delete[] params_[0];
delete[] params_[1];
}
extern int hermes_main(int,char**); char** get() { return params_.get(); }
char* operator[](size_t index) { return params_[index]; }
SERVICE_STATUS service_status; private:
SERVICE_STATUS_HANDLE service_status_handle; std::unique_ptr<char*[]> params_;
int WINAPI WinMain(HINSTANCE,HINSTANCE,LPSTR,int);
static void WINAPI service_main(DWORD,LPTSTR*);
static void WINAPI handler(DWORD);
static int service_install();
static int service_uninstall();
char **params=NULL;
#define free_params() \
do \
{ \
if(NULL!=params) \
{ \
condfree(params[0]); \
condfree(params[1]); \
} \
condfree(params); \
} \
while(0)
#define init_params() \
do \
{ \
free_params(); \
safemalloc(params,sizeof(char *)*2,char **); \
safemalloc(params[0],1*sizeof(char),char *); \
params[0][0]='\0'; \
safemalloc(params[1],1024*sizeof(char),char *); \
} \
while(0)
int WINAPI WinMain(HINSTANCE instance,HINSTANCE previous_instance,LPSTR cmdline,int cmdshow)
{
if(!cmp(cmdline,"-service"))
{
SERVICE_TABLE_ENTRY service_table[]={
{SERVICE_SHORT_NAME,service_main},
{NULL,NULL}
}; };
if(0==StartServiceCtrlDispatcher(service_table)) // Helper functions
{ void show_error(const char* message) {
winerror("Error starting service dispatcher."); MessageBox(NULL, message, ("Error from " SERVICE_NAME), MB_OK | MB_ICONERROR);
return -1;
} }
}
else if(!cmp(cmdline,"-install"))
service_install();
else if(!cmp(cmdline,"-uninstall"))
service_uninstall();
else
{
//we know that hermes can only have one parameter, so
//just copy it
init_params();
strncpy(params[1],cmdline,1024);
hermes_main(2,(char **)params);
free_params();
}
return 0; void show_message(const char* message) {
MessageBox(NULL, message, ("Message from " SERVICE_NAME), MB_OK | MB_ICONINFORMATION);
}
void update_service_status(DWORD new_state) {
service_status.dwCurrentState = new_state;
SetServiceStatus(service_status_handle, &service_status);
}
} }
static int service_install() int WINAPI WinMain(HINSTANCE instance, HINSTANCE previous_instance, LPSTR cmdline, int cmdshow) {
{ try {
SC_HANDLE scm,service_handle; if (strncmp(cmdline, "-service", strlen("-service")) == 0) {
SERVICE_DESCRIPTION service_description; SERVICE_TABLE_ENTRY service_table[] = {
char filename[1024]; {const_cast<LPSTR>(SERVICE_SHORT_NAME), service_main},
string exepath; {NULL, NULL}
};
if(NULL==(scm=OpenSCManager(NULL,NULL,SC_MANAGER_CREATE_SERVICE))) if (!StartServiceCtrlDispatcher(service_table)) {
{ throw std::runtime_error("Error starting service dispatcher");
winerror(_("Error opening connection to the Service Manager.")); }
exit(-1); }
} else if (strncmp(cmdline, "-install", strlen("-install")) == 0) {
if(0==GetModuleFileName(NULL,filename,sizeof(filename))) return service_install();
{ }
winerror(_("Error getting the file name of the process.")); else if (strncmp(cmdline, "-uninstall", strlen("-uninstall")) == 0) {
exit(-1); return service_uninstall();
} }
else {
Parameters params;
strncpy(params[1], cmdline, 1023);
params[1][1023] = '\0'; // Ensure null termination
return hermes_main(2, params.get());
}
}
catch (const std::exception& e) {
show_error(e.what());
return -1;
}
exepath=string("\"")+filename+"\" -service"; return 0;
service_handle=CreateService(
scm, //scm handle
SERVICE_SHORT_NAME, //service name
SERVICE_NAME, //display name
SERVICE_ALL_ACCESS, //desired access
SERVICE_WIN32_OWN_PROCESS, //service type
SERVICE_AUTO_START, //start type
SERVICE_ERROR_NORMAL, //error control
exepath.c_str(), //executable path with arguments
NULL, //load group
NULL, //tag for group id
NULL, //dependencies
NULL, //user name
NULL); //password
if(NULL==service_handle)
{
winerror("Error creating service. Already installed?");
exit(-1);
}
else
winmessage("Service successfully installed.");
//createservice doesn't have a field for description
//so we use ChangeServiceConfig2
service_description.lpDescription=SERVICE_DESCRIPTION_TEXT;
ChangeServiceConfig2(service_handle,SERVICE_CONFIG_DESCRIPTION,(void *)&service_description);
CloseServiceHandle(service_handle);
CloseServiceHandle(scm);
return 0;
} }
static int service_uninstall() static int service_install() {
{ char filename[1024] = {0};
SC_HANDLE scm,service_handle; if (GetModuleFileName(NULL, filename, sizeof(filename)) == 0) {
SERVICE_STATUS status; throw std::runtime_error("Error getting the file name of the process");
}
if(NULL==(scm=OpenSCManager(NULL,NULL,SC_MANAGER_CREATE_SERVICE))) SC_HANDLE scm = OpenSCManager(NULL, NULL, SC_MANAGER_CREATE_SERVICE);
{ if (!scm) {
winerror(_("Error opening connection to the Service Manager.")); throw std::runtime_error("Error opening connection to the Service Manager");
exit(-1); }
}
if(NULL==(service_handle=OpenService(scm,SERVICE_SHORT_NAME,SERVICE_QUERY_STATUS|DELETE))) std::string exepath = "\"" + std::string(filename) + "\" -service";
{
winerror(_("Error opening service.")); SC_HANDLE service_handle = CreateService(
CloseServiceHandle(scm); scm, // SCM handle
exit(-1); SERVICE_SHORT_NAME, // Service name
} SERVICE_NAME, // Display name
SERVICE_ALL_ACCESS, // Desired access
SERVICE_WIN32_OWN_PROCESS, // Service type
SERVICE_AUTO_START, // Start type
SERVICE_ERROR_NORMAL, // Error control
exepath.c_str(), // Executable path
NULL, NULL, NULL, NULL, NULL // Other parameters
);
if (!service_handle) {
CloseServiceHandle(scm);
throw std::runtime_error("Error creating service. Already installed?");
}
// Set service description
SERVICE_DESCRIPTION service_description = {const_cast<LPSTR>(SERVICE_DESCRIPTION_TEXT)};
ChangeServiceConfig2(service_handle, SERVICE_CONFIG_DESCRIPTION, &service_description);
if(0==QueryServiceStatus(service_handle,&status))
{
winerror(_("Error querying service."));
CloseServiceHandle(scm);
CloseServiceHandle(service_handle); CloseServiceHandle(service_handle);
exit(-1);
}
if(SERVICE_STOPPED!=status.dwCurrentState)
{
winerror(SERVICE_NAME _(" is still running. Stop it before trying to uninstall it."));
CloseServiceHandle(scm); CloseServiceHandle(scm);
CloseServiceHandle(service_handle); show_message("Service successfully installed");
exit(-1); return 0;
} }
if(0==DeleteService(service_handle)) static int service_uninstall() {
{ SC_HANDLE scm = OpenSCManager(NULL, NULL, SC_MANAGER_CREATE_SERVICE);
winerror(_("Error deleting service.")); if (!scm) {
throw std::runtime_error("Error opening connection to the Service Manager");
}
SC_HANDLE service_handle = OpenService(scm, SERVICE_SHORT_NAME, SERVICE_QUERY_STATUS | DELETE);
if (!service_handle) {
CloseServiceHandle(scm);
throw std::runtime_error("Error opening service");
}
SERVICE_STATUS status;
if (!QueryServiceStatus(service_handle, &status)) {
CloseServiceHandle(service_handle);
CloseServiceHandle(scm);
throw std::runtime_error("Error querying service");
}
if (status.dwCurrentState != SERVICE_STOPPED) {
CloseServiceHandle(service_handle);
CloseServiceHandle(scm);
throw std::runtime_error(SERVICE_NAME " is still running. Stop it before trying to uninstall it");
}
if (!DeleteService(service_handle)) {
CloseServiceHandle(service_handle);
CloseServiceHandle(scm);
throw std::runtime_error("Error deleting service");
}
CloseServiceHandle(service_handle);
CloseServiceHandle(scm); CloseServiceHandle(scm);
CloseServiceHandle(service_handle); show_message("Service successfully uninstalled");
exit(-1); return 0;
}
CloseServiceHandle(scm);
CloseServiceHandle(service_handle);
winmessage(_("Service successfully uninstalled."));
return 0;
} }
static void WINAPI service_main(DWORD argc,LPTSTR *argv) static void WINAPI service_main(DWORD argc, LPTSTR* argv) {
{ service_status = {
char *tmpstr; .dwServiceType = SERVICE_WIN32,
.dwControlsAccepted = SERVICE_ACCEPT_STOP,
.dwWin32ExitCode = NO_ERROR,
.dwServiceSpecificExitCode = NO_ERROR,
.dwCheckPoint = 0,
.dwWaitHint = 0
};
//configure service_status structure service_status_handle = RegisterServiceCtrlHandler(SERVICE_SHORT_NAME, handler);
service_status.dwServiceType=SERVICE_WIN32; if (!service_status_handle) {
service_status.dwControlsAccepted=0; return;
service_status.dwWin32ExitCode=NO_ERROR; }
service_status.dwServiceSpecificExitCode=NO_ERROR;
service_status.dwCheckPoint=0;
service_status.dwWaitHint=0;
service_status.dwControlsAccepted|=SERVICE_ACCEPT_STOP;
service_status_handle=RegisterServiceCtrlHandler(SERVICE_SHORT_NAME,handler); update_service_status(SERVICE_RUNNING);
if(0!=service_status_handle) try {
{ Parameters params;
//set service status if (GetModuleFileName(NULL, params[1], 1024) == 0) {
ChangeServiceStatus(service_status_handle,service_status,SERVICE_RUNNING); throw std::runtime_error("Error getting module filename");
}
//get the path to the config file char* config_path = strrchr(params[1], '\\');
init_params(); if (!config_path) {
GetModuleFileName(NULL,params[1],1024); throw std::runtime_error("Error finding default config file");
if(NULL==(tmpstr=strrchr(params[1],'\\'))) { winerror("Error finding default config file."); exit(-1); } }
*(++tmpstr)='\0';
strncat(params[1],"hermes.ini",strlen("hermes.ini"));
//now start our main program *(++config_path) = '\0';
hermes_main(2,(char **)params); strncat(params[1], "hermes.ini", strlen("hermes.ini"));
free_params(); hermes_main(2, params.get());
//when we are here, we have been stopped update_service_status(SERVICE_STOP_PENDING);
ChangeServiceStatus(service_status_handle,service_status,SERVICE_STOP_PENDING); update_service_status(SERVICE_STOPPED);
ChangeServiceStatus(service_status_handle,service_status,SERVICE_STOPPED); }
} catch (const std::exception& e) {
show_error(e.what());
update_service_status(SERVICE_STOPPED);
}
} }
static void WINAPI handler(DWORD code) static void WINAPI handler(DWORD code) {
{ if (code == SERVICE_CONTROL_STOP) {
if(SERVICE_CONTROL_STOP==code) quit = true;
{ update_service_status(SERVICE_STOP_PENDING);
quit=true; }
ChangeServiceStatus(service_status_handle,service_status,SERVICE_STOP_PENDING);
}
} }