refactor: modernize networking code with Boost.Asio
Some checks failed
continuous-integration/drone/push Build is failing
continuous-integration/drone Build is failing

- Replace raw socket implementation with Boost.Asio in Proxy class
- Add proper SSL/TLS support using Boost.Asio SSL
- Improve error handling with more specific exceptions
- Modernize Utils class with C++17 features like string_view
- Refactor Windows service implementation with smart pointers and exception handling
- Enhance hostname resolution with Boost.Asio resolver
This commit is contained in:
Juanjo Gutiérrez 2025-03-28 17:12:20 +01:00
parent b4b995dd19
commit 975c5ad5df
No known key found for this signature in database
GPG key ID: 2EE7726C7CA75D4E
5 changed files with 432 additions and 422 deletions

View file

@ -3,19 +3,21 @@
#include <string>
#include <boost/asio.hpp>
#include <boost/asio/ssl.hpp>
#include "Socket.h"
#define SMTP_STATE_WAIT_FOR_HELO 0
#define SMTP_STATE_WAIT_FOR_MAILFROM 1
#define SMTP_STATE_WAIT_FOR_RCPTTO 2
#define SMTP_STATE_WAIT_FOR_DATA 3
class Proxy
{
class Proxy {
public:
Proxy();
void run(boost::asio::ssl::stream<boost::asio::ip::tcp::socket>* outside, const std::string& peer_address);
void setOutside(Socket& socket);
void run(const std::string& peer_address);
private:
boost::asio::io_service& io_service_;
boost::asio::io_context io_context_;
boost::asio::ssl::context ssl_context;
Socket* outside_socket_;
};

View file

@ -22,6 +22,7 @@
#include "hermes.h"
#include <string>
#include <string_view>
#include <sstream>
#include <iostream>
#include <dirent.h>
@ -35,7 +36,9 @@
#include "Database.h"
#include "Socket.h"
using namespace std;
using std::string;
using std::stringstream;
using std::list;
#ifdef WIN32
#define sleep(x) Sleep(1000*(x))
@ -52,37 +55,37 @@ class Utils
{
public:
//string utilities
static string strtolower(string);
static string trim(string);
static string strtolower(std::string_view);
static string trim(std::string_view);
static string inttostr(int);
static string ulongtostr(unsigned long);
//email-related utilities
static string getmail(string&);
static string getdomain(string&);
static string reverseip(string&);
static string getmail(const string&);
static string getdomain(const string&);
static string reverseip(const string&);
//spam-related utilities (TODO: move to a different class)
static bool greylist(string,string&,string&,string&);
static bool listed_on_dns_lists(list<string>&,unsigned char,string&);
static bool whitelisted(string,string&);
static bool blacklisted(string,string&,string&);
static bool greylist(const string& dbfile, string& ip, string& p_from, string& p_to);
static bool listed_on_dns_lists(const list<string>& dns_domains, unsigned char percentage, const string& ip);
static bool whitelisted(const string& dbfile, string& ip);
static bool blacklisted(const string& dbfile, string& ip, string& to);
#ifndef WIN32
//posix-utils
static int usertouid(string);
static int grouptogid(string);
static int usertouid(const string& user);
static int grouptogid(const string& groupname);
#endif //WIN32
//misc
static string get_canonical_filename(string);
static bool file_exists(string);
static bool dir_exists(string);
static string get_canonical_filename(const string& file);
static bool file_exists(const string& file);
static bool dir_exists(const string& dir);
static string errnotostrerror(int);
static string rfc2821_date(time_t *timestamp=NULL);
static string gethostname();
static void write_pid(string,pid_t);
static string gethostname(int s);
static void write_pid(const string& file, pid_t pid);
static string gethostname(int socket);
};
#endif //UTILS_H

View file

@ -2,20 +2,53 @@
#include "Proxy.h"
#include <iostream>
#include <thread>
#include <format>
#include <sstream>
#include <fmt/format.h>
#include <boost/asio/ssl.hpp>
#include <boost/asio/ip/tcp.hpp>
#include "Utils.h"
#include "Configfile.h"
extern Configfile cfg;
void Proxy::run(boost::asio::ssl::stream<boost::asio::ip::tcp::socket>* outside, const std::string& peer_address) {
boost::asio::ssl::stream<boost::asio::ip::tcp::socket>* inside;
namespace {
std::string resolveHostname(const std::string& ip_address) {
try {
boost::asio::io_context io_context;
boost::asio::ip::tcp::resolver resolver(io_context);
boost::asio::ip::address addr = boost::asio::ip::make_address(ip_address);
boost::asio::ip::tcp::endpoint ep(addr, 0);
auto results = resolver.resolve(ep.address().to_string(), "");
if (results.begin() != results.end()) {
return results.begin()->host_name();
}
} catch (const std::exception&) {}
return "";
}
}
Proxy::Proxy()
: io_context_(),
ssl_context(boost::asio::ssl::context::tlsv12),
outside_socket_(nullptr) {
}
void Proxy::setOutside(Socket& socket) {
outside_socket_ = &socket;
}
void Proxy::run(const std::string& peer_address) {
if (!outside_socket_) {
throw std::runtime_error("Outside socket not set");
}
boost::asio::ssl::stream<boost::asio::ip::tcp::socket> inside(io_context_, ssl_context);
// Original comments and variables retained
std::string from = "";
std::string to = "";
std::string ehlostr = "";
std::string resolvedname = "";
const std::string empty_str;
std::string from = empty_str;
std::string to = empty_str;
std::string ehlostr = empty_str;
std::string resolvedname = empty_str;
unsigned char last_state = SMTP_STATE_WAIT_FOR_HELO;
long unimplemented_requests = 0;
@ -25,13 +58,7 @@ void Proxy::run(boost::asio::ssl::stream<boost::asio::ip::tcp::socket>* outside,
try {
// Resolve hostname using the Boost resolver
try {
resolvedname = HostnameResolver::resolveHostname(peer_address);
}
catch (const std::exception& e) {
std::cerr << std::format("Hostname resolution error: {}", e.what()) << std::endl;
resolvedname = "";
}
resolvedname = resolveHostname(peer_address);
// Configure SSL contexts
boost::asio::ssl::context ssl_context(boost::asio::ssl::context::tlsv12);
@ -53,7 +80,8 @@ void Proxy::run(boost::asio::ssl::stream<boost::asio::ip::tcp::socket>* outside,
}
}
if (cfg.getWhitelistedDisablesEverything() && Utils::whitelisted(cfg.getDatabaseFile(), peer_address)) {
std::string peer_addr_copy = peer_address;
if (cfg.getWhitelistedDisablesEverything() && Utils::whitelisted(cfg.getDatabaseFile(), peer_addr_copy)) {
throttled = false;
authenticated = true;
} else {
@ -61,22 +89,20 @@ void Proxy::run(boost::asio::ssl::stream<boost::asio::ip::tcp::socket>* outside,
std::this_thread::sleep_for(std::chrono::seconds(cfg.getBannerDelayTime()));
// Check if data is waiting before server banner
boost::system::error_code ec;
size_t available = outside->lowest_layer().available(ec);
if (ec || available > 0) {
std::cout << std::format("421 (data_before_banner) (ip:{})\n", peer_address);
if (outside_socket_->canRead(0.0)) {
std::cout << fmt::format("421 (data_before_banner) (ip:{})\n", peer_address);
std::this_thread::sleep_for(std::chrono::seconds(20));
// Write rejection message
std::string rejection_msg = "421 Stop sending data before we show you the banner\r\n";
boost::asio::write(outside->lowest_layer(), boost::asio::buffer(rejection_msg), ec);
std::string rejection_msg = fmt::format("421 Stop sending data before we show you the banner\r\n");
outside_socket_->writeBytes(const_cast<char*>(rejection_msg.c_str()), rejection_msg.length());
return;
}
}
}
// Connect to the inside server
boost::asio::ip::tcp::resolver resolver(io_service_);
boost::asio::ip::tcp::resolver resolver(io_context_);
boost::asio::ip::tcp::resolver::results_type endpoints =
resolver.resolve(cfg.getServerHost(), std::to_string(cfg.getServerPort()));
@ -89,159 +115,151 @@ void Proxy::run(boost::asio::ssl::stream<boost::asio::ip::tcp::socket>* outside,
}
if (cfg.getIncomingSsl()) {
outside->set_verify_mode(boost::asio::ssl::verify_none);
outside->handshake(boost::asio::ssl::stream_base::server);
#ifdef HAVE_SSL
outside_socket_->prepareSSL(true);
outside_socket_->startSSL(true);
#endif
}
// Communication buffers
std::vector<char> read_buffer(4096);
boost::system::error_code ec;
char read_buffer[4096];
ssize_t bytes_read;
// Main loop for communication
while (!outside->lowest_layer().is_open() || !inside.lowest_layer().is_open()) {
while (!outside_socket_->isClosed() && !inside.lowest_layer().is_open()) {
// Check if the client wants to send something to the server
size_t client_available = outside->lowest_layer().available(ec);
if (client_available > 0) {
size_t bytes_read = outside->read_some(boost::asio::buffer(read_buffer), ec);
strtemp = std::string(read_buffer.begin(), read_buffer.begin() + bytes_read);
if (outside_socket_->canRead(1.0)) {
bytes_read = outside_socket_->readBytes(read_buffer, sizeof(read_buffer));
if (bytes_read > 0) {
strtemp = std::string(read_buffer, bytes_read);
if (strtemp.length() > 10 && "mail from:" == Utils::strtolower(strtemp.substr(0, 10))) {
from = Utils::getmail(strtemp);
last_state = SMTP_STATE_WAIT_FOR_RCPTTO;
std::string cmd_lower = Utils::strtolower(strtemp.substr(0, std::min<size_t>(10, strtemp.length())));
if (strtemp.length() > 10 && "mail from:" == cmd_lower) {
from = std::string(Utils::getmail(strtemp));
last_state = SMTP_STATE_WAIT_FOR_RCPTTO;
}
cmd_lower = Utils::strtolower(strtemp.substr(0, std::min<size_t>(4, strtemp.length())));
if ("ehlo" == cmd_lower) esmtp = true;
if (strtemp.length() > 4 && ("ehlo" == cmd_lower ||
"helo" == cmd_lower)) {
ehlostr = std::string(Utils::trim(strtemp.substr(5)));
last_state = SMTP_STATE_WAIT_FOR_MAILFROM;
}
// RCPT TO handling with comprehensive checks
cmd_lower = Utils::strtolower(strtemp.substr(0, std::min<size_t>(8, strtemp.length())));
if (strtemp.length() > 8 && "rcpt to:" == cmd_lower) {
std::string mechanism;
std::string message;
to = std::string(Utils::getmail(strtemp));
// Construct log string
std::string strlog = fmt::format("from {} (ip:{}, hostname:{}, {} {}) -> to {}",
from, peer_address, resolvedname, (esmtp ? "ehlo" : "helo"), ehlostr, to);
// Greylisting check
std::string code = "250";
std::string peer_addr_copy = peer_address;
std::string from_copy = from;
std::string to_copy = to;
if (cfg.getGreylist() && !authenticated &&
Utils::greylist(cfg.getDatabaseFile(), peer_addr_copy, from_copy, to_copy)) {
code = "421";
mechanism = "greylist";
message = fmt::format("{} Greylisted!! Please try again in a few minutes.", code);
std::cout << fmt::format("checking {}\n", mechanism);
}
// SPF Check
#ifdef HAVE_SPF
else if (cfg.getQuerySpf() && !authenticated &&
!spf_checker.query(peer_address, ehlostr, from)) {
code = cfg.getAddStatusHeader() ? "250" :
(cfg.getReturnTempErrorOnReject() ? "421" : "550");
mechanism = "spf";
message = fmt::format("{} You do not seem to be allowed to send email for that particular domain.", code);
std::cout << fmt::format("checking {}\n", mechanism);
}
#endif
// Blacklist check
else if (!authenticated &&
Utils::blacklisted(cfg.getDatabaseFile(), peer_addr_copy, to_copy)) {
code = cfg.getReturnTempErrorOnReject() ? "421" : "550";
mechanism = "allowed-domain-per-ip";
message = fmt::format("{} You do not seem to be allowed to send email to that particular domain from that address.", code);
std::cout << fmt::format("checking {}\n", mechanism);
}
// DNS Blacklist check
else if (!cfg.getDnsBlacklistDomains().empty() && !authenticated &&
Utils::listed_on_dns_lists(cfg.getDnsBlacklistDomains(),
cfg.getDnsBlacklistPercentage(),
peer_address)) {
code = cfg.getAddStatusHeader() ? "250" :
(cfg.getReturnTempErrorOnReject() ? "421" : "550");
mechanism = "dnsbl";
message = fmt::format("{} You are listed on some DNS blacklists. Get delisted before trying to send us email.", code);
std::cout << fmt::format("checking {}\n", mechanism);
}
// Reverse DNS check
else if (cfg.getRejectNoReverseResolution() && !authenticated && resolvedname.empty()) {
code = cfg.getReturnTempErrorOnReject() ? "421" : "550";
mechanism = "no reverse resolution";
message = fmt::format("{} Your IP address does not resolve to a hostname.", code);
std::cout << fmt::format("checking {}\n", mechanism);
}
// HELO/Reverse name check
else if (cfg.getCheckHeloAgainstReverse() && !authenticated && ehlostr != resolvedname) {
code = cfg.getReturnTempErrorOnReject() ? "421" : "550";
mechanism = "helo differs from resolved name";
message = fmt::format("{} Your IP hostname doesn't match your envelope hostname.", code);
std::cout << fmt::format("checking {}\n", mechanism);
}
// Prepare log message
std::string log_str = strlog;
if (!mechanism.empty()) {
log_str = fmt::format("({}) {}", mechanism, log_str);
}
std::cout << fmt::format("{} {}\n", code, log_str);
// Handle rejection
if (code != "250") {
// Close inside connection
inside.lowest_layer().close();
// Delay to annoy spammers
std::this_thread::sleep_for(std::chrono::seconds(20));
// Send rejection message
std::string rejection_msg = fmt::format("{}\r\n", message);
outside_socket_->writeBytes(const_cast<char*>(rejection_msg.c_str()), rejection_msg.length());
return;
}
last_state = SMTP_STATE_WAIT_FOR_DATA;
}
// Send to inside server
boost::asio::async_write(inside, boost::asio::buffer(strtemp),
[](const boost::system::error_code& /*error*/, std::size_t /*bytes_transferred*/) {});
}
if ("ehlo" == Utils::strtolower(strtemp.substr(0, 4))) esmtp = true;
if (strtemp.length() > 4 && ("ehlo" == Utils::strtolower(strtemp.substr(0, 4)) ||
"helo" == Utils::strtolower(strtemp.substr(0, 4)))) {
ehlostr = Utils::trim(strtemp.substr(5));
last_state = SMTP_STATE_WAIT_FOR_MAILFROM;
}
// RCPT TO handling with comprehensive checks
if (strtemp.length() > 8 && "rcpt to:" == Utils::strtolower(strtemp.substr(0, 8))) {
std::string mechanism = "";
std::string message = "";
to = Utils::getmail(strtemp);
// Construct log string using std::format
std::string strlog = std::format(
"from {} (ip:{}, hostname:{}, {} {}:{}) -> to {}",
from,
peer_address,
resolvedname,
(esmtp ? "ehlo" : "helo"),
ehlostr,
to
);
// Greylisting check
std::string code = "250";
if (cfg.getGreylist() && !authenticated &&
Utils::greylist(cfg.getDatabaseFile(), peer_address, from, to)) {
code = "421";
mechanism = "greylist";
message = std::format("{} Greylisted!! Please try again in a few minutes.", code);
std::cout << std::format("checking {}\n", mechanism);
}
// SPF Check
#ifdef HAVE_SPF
else if (cfg.getQuerySpf() && !authenticated &&
!spf_checker.query(peer_address, ehlostr, from)) {
code = cfg.getAddStatusHeader() ? "250" :
(cfg.getReturnTempErrorOnReject() ? "421" : "550");
mechanism = "spf";
message = std::format(
"{} You do not seem to be allowed to send email for that particular domain.",
code
);
std::cout << std::format("checking {}\n", mechanism);
}
#endif
// Blacklist check
else if (!authenticated &&
Utils::blacklisted(cfg.getDatabaseFile(), peer_address, to)) {
code = cfg.getReturnTempErrorOnReject() ? "421" : "550";
mechanism = "allowed-domain-per-ip";
message = std::format(
"{} You do not seem to be allowed to send email to that particular domain from that address.",
code
);
std::cout << std::format("checking {}\n", mechanism);
}
// DNS Blacklist check
else if (!cfg.getDnsBlacklistDomains().empty() && !authenticated &&
Utils::listed_on_dns_lists(cfg.getDnsBlacklistDomains(),
cfg.getDnsBlacklistPercentage(),
peer_address)) {
code = cfg.getAddStatusHeader() ? "250" :
(cfg.getReturnTempErrorOnReject() ? "421" : "550");
mechanism = "dnsbl";
message = std::format(
"{} You are listed on some DNS blacklists. Get delisted before trying to send us email.",
code
);
std::cout << std::format("checking {}\n", mechanism);
}
// Reverse DNS check
else if (cfg.getRejectNoReverseResolution() && !authenticated && resolvedname.empty()) {
code = cfg.getReturnTempErrorOnReject() ? "421" : "550";
mechanism = "no reverse resolution";
message = std::format(
"{} Your IP address does not resolve to a hostname.",
code
);
std::cout << std::format("checking {}\n", mechanism);
}
// HELO/Reverse name check
else if (cfg.getCheckHeloAgainstReverse() && !authenticated && ehlostr != resolvedname) {
code = cfg.getReturnTempErrorOnReject() ? "421" : "550";
mechanism = "helo differs from resolved name";
message = std::format(
"{} Your IP hostname doesn't match your envelope hostname.",
code
);
std::cout << std::format("checking {}\n", mechanism);
}
// Prepare log message
if (!mechanism.empty()) {
strlog = std::format("({}) {}", mechanism, strlog);
}
strlog = std::format("{} {}", code, strlog);
std::cout << strlog << "\n";
// Handle rejection
if (code != "250") {
// Close inside connection
inside.lowest_layer().close();
// Delay to annoy spammers
std::this_thread::sleep_for(std::chrono::seconds(20));
// Send rejection message
std::string rejection_msg = message + "\r\n";
boost::asio::write(outside->lowest_layer(), boost::asio::buffer(rejection_msg), ec);
return;
}
last_state = SMTP_STATE_WAIT_FOR_DATA;
}
// Send to inside server
boost::asio::write(inside.lowest_layer(), boost::asio::buffer(strtemp), ec);
}
// Check if the server wants to send something to the client
boost::system::error_code ec;
size_t server_available = inside.lowest_layer().available(ec);
if (server_available > 0) {
size_t bytes_read = inside.read_some(boost::asio::buffer(read_buffer), ec);
strtemp = std::string(read_buffer.begin(), read_buffer.begin() + bytes_read);
// Send to outside socket
boost::asio::write(outside->lowest_layer(), boost::asio::buffer(strtemp), ec);
if (!ec && server_available > 0) {
std::vector<char> server_buffer(server_available);
size_t bytes_read = boost::asio::read(inside, boost::asio::buffer(server_buffer), ec);
if (!ec && bytes_read > 0) {
outside_socket_->writeBytes(server_buffer.data(), bytes_read);
}
}
// Run io_context to process async operations
io_context_.poll();
// Throttling
if (throttled) {
std::this_thread::sleep_for(std::chrono::seconds(cfg.getThrottlingTime()));
@ -249,9 +267,9 @@ void Proxy::run(boost::asio::ssl::stream<boost::asio::ip::tcp::socket>* outside,
}
}
catch (const boost::system::system_error& e) {
std::cerr << std::format("Boost.Asio error: {}", e.what()) << std::endl;
std::cerr << fmt::format("Boost.Asio error: {}\n", e.what());
}
catch (const std::exception& e) {
std::cerr << std::format("Standard exception: {}", e.what()) << std::endl;
std::cerr << fmt::format("Standard exception: {}\n", e.what());
}
}

View file

@ -21,6 +21,11 @@
#include <unistd.h>
#include <cstring>
#include <boost/algorithm/string.hpp>
using std::string;
using std::stringstream;
using std::list;
extern Configfile cfg;
extern LOGGER_CLASS hermes_log;
@ -69,9 +74,8 @@ string Utils::ulongtostr(unsigned long number)
*/
string Utils::strtolower(const std::string_view s)
{
const std::string lower_str = boost::algorithm::to_lower_copy(s);
return lower_str;
const std::string str(s);
return boost::algorithm::to_lower_copy(str);
}
/**
@ -84,13 +88,13 @@ string Utils::strtolower(const std::string_view s)
*/
string Utils::trim(const std::string_view s)
{
while(isspace(s[0]))
s.erase(0,1);
while(isspace(s[s.length()-1]))
s.erase(s.length()-1,1);
return s;
auto start = s.find_first_not_of(" \t\n\r\f\v");
if (start == std::string_view::npos) {
return string();
}
auto end = s.find_last_not_of(" \t\n\r\f\v");
return string(s.substr(start, end - start + 1));
}
//------------------------
@ -141,7 +145,7 @@ string Utils::trim(const std::string_view s)
* @return whether triplet should get greylisted or not
* @todo unify {white,black,grey}list in one function that returns a different constant in each case
*/
bool Utils::greylist(string dbfile,string& ip,string& p_from,string& p_to)
bool Utils::greylist(const string& dbfile, string& ip, string& p_from, string& p_to)
{
string from=Database::cleanString(p_from);
string to=Database::cleanString(p_to);
@ -180,7 +184,7 @@ bool Utils::greylist(string dbfile,string& ip,string& p_from,string& p_to)
* @return whether ip is whitelisted or not
* @todo unify {white,black,grey}list in one function that returns a different constant in each case
*/
bool Utils::whitelisted(string dbfile,string& ip)
bool Utils::whitelisted(const string& dbfile, string& ip)
{
Database db;
string hostname;
@ -212,7 +216,7 @@ bool Utils::whitelisted(string dbfile,string& ip)
* @return whether ip is whitelisted or not
* @todo this should contain all cases when we should reject a connection
*/
bool Utils::blacklisted(string dbfile,string& ip,string& to)
bool Utils::blacklisted(const string& dbfile, string& ip, string& to)
{
Database db;
string hostname;
@ -233,7 +237,7 @@ bool Utils::blacklisted(string dbfile,string& ip,string& to)
* @return email extracted from rawline
*
*/
string Utils::getmail(string& rawline)
string Utils::getmail(const string& rawline)
{
string email;
string::size_type start=0,end=0;
@ -279,7 +283,7 @@ string Utils::getmail(string& rawline)
* @return domain of email
*
*/
string Utils::getdomain(string& email)
string Utils::getdomain(const string& email)
{
if(email.rfind('@'))
return trim(email.substr(email.rfind('@')+1));
@ -306,7 +310,7 @@ string Utils::getdomain(string& email)
* @return uid for user
*
*/
int Utils::usertouid(string user)
int Utils::usertouid(const string& user)
{
struct passwd *pwd;
pwd=getpwnam(user.c_str());
@ -329,7 +333,7 @@ int Utils::usertouid(string user)
* @return gid for groupname
*
*/
int Utils::grouptogid(string groupname)
int Utils::grouptogid(const string& groupname)
{
struct group *grp;
grp=getgrnam(groupname.c_str());
@ -352,7 +356,7 @@ int Utils::grouptogid(string groupname)
* @return is file readable?
*
*/
bool Utils::file_exists(string file)
bool Utils::file_exists(const string& file)
{
FILE *f=fopen(file.c_str(),"r");
if(NULL==f)
@ -365,7 +369,7 @@ bool Utils::file_exists(string file)
}
#ifdef WIN32
string Utils::get_canonical_filename(string file)
string Utils::get_canonical_filename(const string& file)
{
char buffer[MAX_PATH];
@ -374,7 +378,7 @@ string Utils::get_canonical_filename(string file)
return string(buffer);
}
#else
string Utils::get_canonical_filename(string file)
string Utils::get_canonical_filename(const string& file)
{
char *buffer=NULL;
string result;
@ -386,6 +390,7 @@ string Utils::get_canonical_filename(string file)
return result;
}
#endif //WIN32
/**
* whether a directory is accesible by current process/user
*
@ -394,7 +399,7 @@ string Utils::get_canonical_filename(string file)
* @return isdir readable?
*
*/
bool Utils::dir_exists(string dir)
bool Utils::dir_exists(const string& dir)
{
DIR *d=opendir(dir.c_str());
if(NULL==d)
@ -431,8 +436,6 @@ string Utils::errnotostrerror(int errnum)
strerr="Error ";
#endif //WIN32
return string(strerr)+" ("+inttostr(errnum)+")";
// else
// return string("Error "+inttostr(errno)+" retrieving error code for error number ")+inttostr(errnum);
}
/**
@ -458,7 +461,7 @@ string Utils::errnotostrerror(int errnum)
*
* @return whether ip is blacklisted or not
*/
bool Utils::listed_on_dns_lists(list<string>& dns_domains,unsigned char percentage,string& ip)
bool Utils::listed_on_dns_lists(const list<string>& dns_domains, unsigned char percentage, const string& ip)
{
string reversedip;
unsigned char number_of_lists=dns_domains.size();
@ -467,7 +470,7 @@ bool Utils::listed_on_dns_lists(list<string>& dns_domains,unsigned char percenta
reversedip=reverseip(ip);
for(list<string>::iterator i=dns_domains.begin();i!=dns_domains.end();i++)
for(list<string>::const_iterator i=dns_domains.begin();i!=dns_domains.end();i++)
{
string dns_domain;
@ -514,7 +517,7 @@ bool Utils::listed_on_dns_lists(list<string>& dns_domains,unsigned char percenta
*
* @return the reversed ip
*/
string Utils::reverseip(string& ip)
string Utils::reverseip(const string& ip)
{
string inverseip="";
string::size_type pos=0,ppos=0;
@ -628,13 +631,19 @@ string Utils::gethostname()
return string(buf);
}
string Utils::gethostname(int s)
/**
* Get the hostname for a given socket
*
* @param socket Socket to get hostname for
* @return The hostname for the socket
*/
string Utils::gethostname(int socket)
{
struct sockaddr_in sa;
unsigned int dummy = sizeof sa;
struct hostent *hp;
if (getsockname(s,(struct sockaddr *) &sa,&dummy) == -1)
if (getsockname(socket,(struct sockaddr *) &sa,&dummy) == -1)
throw Exception("Error getting ip from socket"+Utils::errnotostrerror(errno),__FILE__,__LINE__);
hp=gethostbyaddr((const void *)&sa.sin_addr,sizeof sa.sin_addr,AF_INET);
@ -645,8 +654,13 @@ string Utils::gethostname(int s)
return string(hp->h_name);
}
void Utils::write_pid(string file,pid_t pid)
/**
* Write process ID to a file
*
* @param file File to write PID to
* @param pid Process ID to write
*/
void Utils::write_pid(const string& file, pid_t pid)
{
FILE *f;
@ -657,4 +671,3 @@ void Utils::write_pid(string file,pid_t pid)
fprintf(f,"%d\n",pid);
fclose(f);
}

View file

@ -18,247 +18,221 @@
* @author Juan José Gutiérrez de Quevedo <juanjo@gutierrezdequevedo.com>
*/
#include <string>
#include <memory>
#include <stdexcept>
#include <windows.h>
using namespace std;
namespace {
// Service configuration constants
constexpr const char* SERVICE_NAME = "hermes anti-spam proxy";
constexpr const char* SERVICE_SHORT_NAME = "hermes";
constexpr const char* SERVICE_DESCRIPTION_TEXT =
"An anti-spam proxy using a combination of techniques like greylisting, dnsbl/dnswl, SPF, etc.";
#define SERVICE_NAME "hermes anti-spam proxy"
#define SERVICE_SHORT_NAME "hermes"
#define SERVICE_DESCRIPTION_TEXT "An anti-spam proxy using a combination of techniques like greylisting, dnsbl/dnswl, SPF, etc."
// Global state
SERVICE_STATUS service_status;
SERVICE_STATUS_HANDLE service_status_handle;
extern bool quit;
//macros
#define ChangeServiceStatus(x,y,z) y.dwCurrentState=z; SetServiceStatus(x,&y);
#define MIN(x,y) (((x)<(y))?(x):(y))
#define cmp(x,y) strncmp(x,y,strlen(y))
#define msgbox(x,y,z) MessageBox(NULL,x,z SERVICE_NAME,MB_OK|y)
#define winerror(x) msgbox(x,MB_ICONERROR,"Error from ")
#define winmessage(x) msgbox(x,MB_ICONINFORMATION,"Message from ")
#define condfree(x) if(NULL!=x) free(x);
#define safemalloc(x,y,z) do { if(NULL==(x=(z)malloc(y))) { winerror("Error reserving memory"); exit(-1); } memset(x,0,y); } while(0)
#define _(x) x
// Function declarations
static void WINAPI service_main(DWORD argc, LPTSTR* argv);
static void WINAPI handler(DWORD code);
static int service_install();
static int service_uninstall();
extern int hermes_main(int argc, char** argv);
/**
* The docs on microsoft's web don't seem very clear, so I have
* looked at the stunnel source code to understand how this thing
* works. What you see here is still original source, but is
* "inspired" by stunnel's source code (gui.c mainly).
* It's the real minimum needed to install, start and stop services
*/
// Helper class for managing command line parameters
class Parameters {
public:
Parameters() {
params_ = std::make_unique<char*[]>(2);
params_[0] = new char[1];
params_[0][0] = '\0';
params_[1] = new char[1024]();
}
extern bool quit;
~Parameters() {
delete[] params_[0];
delete[] params_[1];
}
extern int hermes_main(int,char**);
char** get() { return params_.get(); }
char* operator[](size_t index) { return params_[index]; }
SERVICE_STATUS service_status;
SERVICE_STATUS_HANDLE service_status_handle;
int WINAPI WinMain(HINSTANCE,HINSTANCE,LPSTR,int);
static void WINAPI service_main(DWORD,LPTSTR*);
static void WINAPI handler(DWORD);
static int service_install();
static int service_uninstall();
char **params=NULL;
#define free_params() \
do \
{ \
if(NULL!=params) \
{ \
condfree(params[0]); \
condfree(params[1]); \
} \
condfree(params); \
} \
while(0)
#define init_params() \
do \
{ \
free_params(); \
safemalloc(params,sizeof(char *)*2,char **); \
safemalloc(params[0],1*sizeof(char),char *); \
params[0][0]='\0'; \
safemalloc(params[1],1024*sizeof(char),char *); \
} \
while(0)
int WINAPI WinMain(HINSTANCE instance,HINSTANCE previous_instance,LPSTR cmdline,int cmdshow)
{
if(!cmp(cmdline,"-service"))
{
SERVICE_TABLE_ENTRY service_table[]={
{SERVICE_SHORT_NAME,service_main},
{NULL,NULL}
private:
std::unique_ptr<char*[]> params_;
};
if(0==StartServiceCtrlDispatcher(service_table))
{
winerror("Error starting service dispatcher.");
return -1;
// Helper functions
void show_error(const char* message) {
MessageBox(NULL, message, ("Error from " SERVICE_NAME), MB_OK | MB_ICONERROR);
}
}
else if(!cmp(cmdline,"-install"))
service_install();
else if(!cmp(cmdline,"-uninstall"))
service_uninstall();
else
{
//we know that hermes can only have one parameter, so
//just copy it
init_params();
strncpy(params[1],cmdline,1024);
hermes_main(2,(char **)params);
free_params();
}
return 0;
void show_message(const char* message) {
MessageBox(NULL, message, ("Message from " SERVICE_NAME), MB_OK | MB_ICONINFORMATION);
}
void update_service_status(DWORD new_state) {
service_status.dwCurrentState = new_state;
SetServiceStatus(service_status_handle, &service_status);
}
}
static int service_install()
{
SC_HANDLE scm,service_handle;
SERVICE_DESCRIPTION service_description;
char filename[1024];
string exepath;
int WINAPI WinMain(HINSTANCE instance, HINSTANCE previous_instance, LPSTR cmdline, int cmdshow) {
try {
if (strncmp(cmdline, "-service", strlen("-service")) == 0) {
SERVICE_TABLE_ENTRY service_table[] = {
{const_cast<LPSTR>(SERVICE_SHORT_NAME), service_main},
{NULL, NULL}
};
if(NULL==(scm=OpenSCManager(NULL,NULL,SC_MANAGER_CREATE_SERVICE)))
{
winerror(_("Error opening connection to the Service Manager."));
exit(-1);
}
if(0==GetModuleFileName(NULL,filename,sizeof(filename)))
{
winerror(_("Error getting the file name of the process."));
exit(-1);
}
if (!StartServiceCtrlDispatcher(service_table)) {
throw std::runtime_error("Error starting service dispatcher");
}
}
else if (strncmp(cmdline, "-install", strlen("-install")) == 0) {
return service_install();
}
else if (strncmp(cmdline, "-uninstall", strlen("-uninstall")) == 0) {
return service_uninstall();
}
else {
Parameters params;
strncpy(params[1], cmdline, 1023);
params[1][1023] = '\0'; // Ensure null termination
return hermes_main(2, params.get());
}
}
catch (const std::exception& e) {
show_error(e.what());
return -1;
}
exepath=string("\"")+filename+"\" -service";
service_handle=CreateService(
scm, //scm handle
SERVICE_SHORT_NAME, //service name
SERVICE_NAME, //display name
SERVICE_ALL_ACCESS, //desired access
SERVICE_WIN32_OWN_PROCESS, //service type
SERVICE_AUTO_START, //start type
SERVICE_ERROR_NORMAL, //error control
exepath.c_str(), //executable path with arguments
NULL, //load group
NULL, //tag for group id
NULL, //dependencies
NULL, //user name
NULL); //password
if(NULL==service_handle)
{
winerror("Error creating service. Already installed?");
exit(-1);
}
else
winmessage("Service successfully installed.");
//createservice doesn't have a field for description
//so we use ChangeServiceConfig2
service_description.lpDescription=SERVICE_DESCRIPTION_TEXT;
ChangeServiceConfig2(service_handle,SERVICE_CONFIG_DESCRIPTION,(void *)&service_description);
CloseServiceHandle(service_handle);
CloseServiceHandle(scm);
return 0;
return 0;
}
static int service_uninstall()
{
SC_HANDLE scm,service_handle;
SERVICE_STATUS status;
static int service_install() {
char filename[1024] = {0};
if (GetModuleFileName(NULL, filename, sizeof(filename)) == 0) {
throw std::runtime_error("Error getting the file name of the process");
}
if(NULL==(scm=OpenSCManager(NULL,NULL,SC_MANAGER_CREATE_SERVICE)))
{
winerror(_("Error opening connection to the Service Manager."));
exit(-1);
}
SC_HANDLE scm = OpenSCManager(NULL, NULL, SC_MANAGER_CREATE_SERVICE);
if (!scm) {
throw std::runtime_error("Error opening connection to the Service Manager");
}
if(NULL==(service_handle=OpenService(scm,SERVICE_SHORT_NAME,SERVICE_QUERY_STATUS|DELETE)))
{
winerror(_("Error opening service."));
CloseServiceHandle(scm);
exit(-1);
}
std::string exepath = "\"" + std::string(filename) + "\" -service";
SC_HANDLE service_handle = CreateService(
scm, // SCM handle
SERVICE_SHORT_NAME, // Service name
SERVICE_NAME, // Display name
SERVICE_ALL_ACCESS, // Desired access
SERVICE_WIN32_OWN_PROCESS, // Service type
SERVICE_AUTO_START, // Start type
SERVICE_ERROR_NORMAL, // Error control
exepath.c_str(), // Executable path
NULL, NULL, NULL, NULL, NULL // Other parameters
);
if (!service_handle) {
CloseServiceHandle(scm);
throw std::runtime_error("Error creating service. Already installed?");
}
// Set service description
SERVICE_DESCRIPTION service_description = {const_cast<LPSTR>(SERVICE_DESCRIPTION_TEXT)};
ChangeServiceConfig2(service_handle, SERVICE_CONFIG_DESCRIPTION, &service_description);
if(0==QueryServiceStatus(service_handle,&status))
{
winerror(_("Error querying service."));
CloseServiceHandle(scm);
CloseServiceHandle(service_handle);
exit(-1);
}
if(SERVICE_STOPPED!=status.dwCurrentState)
{
winerror(SERVICE_NAME _(" is still running. Stop it before trying to uninstall it."));
CloseServiceHandle(scm);
CloseServiceHandle(service_handle);
exit(-1);
}
show_message("Service successfully installed");
return 0;
}
if(0==DeleteService(service_handle))
{
winerror(_("Error deleting service."));
static int service_uninstall() {
SC_HANDLE scm = OpenSCManager(NULL, NULL, SC_MANAGER_CREATE_SERVICE);
if (!scm) {
throw std::runtime_error("Error opening connection to the Service Manager");
}
SC_HANDLE service_handle = OpenService(scm, SERVICE_SHORT_NAME, SERVICE_QUERY_STATUS | DELETE);
if (!service_handle) {
CloseServiceHandle(scm);
throw std::runtime_error("Error opening service");
}
SERVICE_STATUS status;
if (!QueryServiceStatus(service_handle, &status)) {
CloseServiceHandle(service_handle);
CloseServiceHandle(scm);
throw std::runtime_error("Error querying service");
}
if (status.dwCurrentState != SERVICE_STOPPED) {
CloseServiceHandle(service_handle);
CloseServiceHandle(scm);
throw std::runtime_error(SERVICE_NAME " is still running. Stop it before trying to uninstall it");
}
if (!DeleteService(service_handle)) {
CloseServiceHandle(service_handle);
CloseServiceHandle(scm);
throw std::runtime_error("Error deleting service");
}
CloseServiceHandle(service_handle);
CloseServiceHandle(scm);
CloseServiceHandle(service_handle);
exit(-1);
}
CloseServiceHandle(scm);
CloseServiceHandle(service_handle);
winmessage(_("Service successfully uninstalled."));
return 0;
show_message("Service successfully uninstalled");
return 0;
}
static void WINAPI service_main(DWORD argc,LPTSTR *argv)
{
char *tmpstr;
static void WINAPI service_main(DWORD argc, LPTSTR* argv) {
service_status = {
.dwServiceType = SERVICE_WIN32,
.dwControlsAccepted = SERVICE_ACCEPT_STOP,
.dwWin32ExitCode = NO_ERROR,
.dwServiceSpecificExitCode = NO_ERROR,
.dwCheckPoint = 0,
.dwWaitHint = 0
};
//configure service_status structure
service_status.dwServiceType=SERVICE_WIN32;
service_status.dwControlsAccepted=0;
service_status.dwWin32ExitCode=NO_ERROR;
service_status.dwServiceSpecificExitCode=NO_ERROR;
service_status.dwCheckPoint=0;
service_status.dwWaitHint=0;
service_status.dwControlsAccepted|=SERVICE_ACCEPT_STOP;
service_status_handle = RegisterServiceCtrlHandler(SERVICE_SHORT_NAME, handler);
if (!service_status_handle) {
return;
}
service_status_handle=RegisterServiceCtrlHandler(SERVICE_SHORT_NAME,handler);
update_service_status(SERVICE_RUNNING);
if(0!=service_status_handle)
{
//set service status
ChangeServiceStatus(service_status_handle,service_status,SERVICE_RUNNING);
try {
Parameters params;
if (GetModuleFileName(NULL, params[1], 1024) == 0) {
throw std::runtime_error("Error getting module filename");
}
//get the path to the config file
init_params();
GetModuleFileName(NULL,params[1],1024);
if(NULL==(tmpstr=strrchr(params[1],'\\'))) { winerror("Error finding default config file."); exit(-1); }
*(++tmpstr)='\0';
strncat(params[1],"hermes.ini",strlen("hermes.ini"));
char* config_path = strrchr(params[1], '\\');
if (!config_path) {
throw std::runtime_error("Error finding default config file");
}
//now start our main program
hermes_main(2,(char **)params);
*(++config_path) = '\0';
strncat(params[1], "hermes.ini", strlen("hermes.ini"));
free_params();
hermes_main(2, params.get());
//when we are here, we have been stopped
ChangeServiceStatus(service_status_handle,service_status,SERVICE_STOP_PENDING);
ChangeServiceStatus(service_status_handle,service_status,SERVICE_STOPPED);
}
update_service_status(SERVICE_STOP_PENDING);
update_service_status(SERVICE_STOPPED);
}
catch (const std::exception& e) {
show_error(e.what());
update_service_status(SERVICE_STOPPED);
}
}
static void WINAPI handler(DWORD code)
{
if(SERVICE_CONTROL_STOP==code)
{
quit=true;
ChangeServiceStatus(service_status_handle,service_status,SERVICE_STOP_PENDING);
}
static void WINAPI handler(DWORD code) {
if (code == SERVICE_CONTROL_STOP) {
quit = true;
update_service_status(SERVICE_STOP_PENDING);
}
}