// =================================================================
//
//  Copyright (C) 2001 Maciej Sobczak
//  Copyright (C) 2003 Alex Vinokur - minor (cosmetic) changes
//
//  For conditions of distribution and use, see
//  copyright notice in common.h
//
// =================================================================


// #################################################################
//
//  SOFTWARE : C++ Stream-Compatible TCP/IP Sockets Demo Application
//  FILE     : sockets.cpp
//
//  DESCRIPTION :
//         Wrapper classes that can be used
//         as iostream-compatible TCP/IP sockets.
//         Class implementation (non-template methods).
//
// #################################################################


// ==================
#include "sockets2.h"

#include <climits>

#ifndef WIN32
// this is for Linux
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <cerrno>
#include <netdb.h>
#include <arpa/inet.h>
#define INVALID_SOCKET -1
#define SOCKET_ERROR -1
#endif
// ==================


// -------------------------
// --- class SocketRunTimeException
// -------------------------
// ---------
// Constructor-1
// Builds a run-time socket error carrying `what` and the calling thread's
// current socket error code (WSAGetLastError() on Windows, errno elsewhere).
//
// The error code is captured before the trace statement: tracing streams text
// and may perform I/O, which can overwrite errno. Capturing afterwards would
// make the exception report the trace's status instead of the socket failure.
SocketRunTimeException::SocketRunTimeException (const string &what)
	: 
	runtime_error(what)
{
#ifdef WIN32
  errnum_ = ::WSAGetLastError();
#else
  errnum_ = errno;
#endif
SET_DETAILED_TRACE ("(Ctor) what = " << what);
}

// ---------
// Destructor
// Destructor: only emits the destructor trace; no explicit cleanup is needed.
SocketRunTimeException::~SocketRunTimeException () throw ()
{
SET_DTOR_TRACE;
}

// ---------
int SocketRunTimeException::errornumber() const throw()
{
SET_TRACE;
  return errnum_;
}

// ---------
// Returns the runtime_error message followed by " error number: <code>".
//
// The text is rebuilt into msg_ on every call, so the returned pointer stays
// valid only until the next what() call on the same object (or its
// destruction).
//
// Formatting allocates and may throw std::bad_alloc. An exception escaping
// this throw() function would call std::unexpected()/terminate, so on any
// failure the plain runtime_error message is returned instead.
const char * SocketRunTimeException::what() const throw()
{
SET_TRACE;
  try
  {
    ostringstream ss;
    ss << runtime_error::what();
    ss << " error number: " << errnum_;
    msg_ = ss.str();
    return msg_.c_str();
  }
  catch (...)
  {
    return runtime_error::what();
  }
}



// ------------------------------
// --- class SocketLogicException
// ------------------------------
// ---------
// Constructor-1
// Builds a logic error describing misuse of the socket wrapper (for example
// calling a method while the socket is in the wrong state).
SocketLogicException::SocketLogicException (const string &what)
	: logic_error (what)
{
SET_DETAILED_TRACE ("what = " << what);
}



// --------------------------------------------
// -- class TCPSocketWrapper::TCPAcceptedSocket
// --------------------------------------------
// ----------------------------
// Constructor-1
// Holds a descriptor returned by ::accept() together with the peer address
// it reported.
TCPSocketWrapper::TCPAcceptedSocket::TCPAcceptedSocket (socket_type s, sockaddr_in a)
	: sock_(s)
	, addr_(a)
{
SET_TRACE;
}

// ----------------------------
// Copy Constructor
// Copy constructor: copies the descriptor value and the peer address.
// The descriptor itself is not duplicated; both objects hold the same value.
TCPSocketWrapper::TCPAcceptedSocket::TCPAcceptedSocket (const TCPAcceptedSocket &a)
	: sock_(a.sock_)
	, addr_(a.addr_)
{
SET_TRACE;
}



// --------------------------
// --- class TCPSocketWrapper
// --------------------------
// ----------------------------
// Constructor-0
// Creates a wrapper that does not own a socket yet (state CLOSED).
//
// The descriptor is set to INVALID_SOCKET and the stored address is zeroed
// so neither member is left indeterminate. Without this, get_sockaddress()
// (which performs no state check) would return uninitialized bytes on a
// freshly constructed object.
TCPSocketWrapper::TCPSocketWrapper()
	: 
	sock_(INVALID_SOCKET),
	sockaddress_(),
	sockstate_(CLOSED)
{
SET_TRACE;
}

// ----------------------------
// Constructor-1
// Wraps a connection returned by accept(): copies its descriptor and peer
// address and starts in the ACCEPTED state. From then on this object's
// close()/destructor is what closes the descriptor.
TCPSocketWrapper::TCPSocketWrapper (const TCPSocketWrapper::TCPAcceptedSocket &as_i)
	: sock_(as_i.sock_)
	, sockaddress_(as_i.addr_)
	, sockstate_(ACCEPTED)
{
SET_CTOR_TRACE;
}

// ----------------------------
// Destructor
// Destructor: closes the descriptor if the wrapper still owns an open one.
// The result of the close call is deliberately ignored because a destructor
// must not throw; call close() explicitly to get error reporting.
TCPSocketWrapper::~TCPSocketWrapper()
{
SET_DTOR_TRACE;
  if (sockstate_ == CLOSED)
  {
    return;
  }

#ifdef WIN32
  closesocket(sock_);
#else
  ::close(sock_);
#endif
}

// ----------------------------
// Closes a socket that listen() created but could not bring into the
// LISTENING state. The pending socket error code is saved and restored around
// the close, so the exception thrown next still reports the bind()/listen()
// failure rather than the outcome of closing. (Templated on the descriptor
// type so no project typedef has to be named here.)
template <typename SocketT>
static void listenDiscardSocket (SocketT s)
{
#ifdef WIN32
  const int saved = ::WSAGetLastError();
  closesocket(s);
  ::WSASetLastError(saved);
#else
  const int saved = errno;
  ::close(s);
  errno = saved;
#endif
}

// Creates a TCP socket bound to all local interfaces (INADDR_ANY) on `port`
// and starts listening with the given `backlog`.
//
// Throws SocketLogicException in two cases:
//   - the wrapper is not CLOSED;
//   - `port` is outside 0..65535 (it would otherwise be silently truncated by
//     the u_short conversion and bind a different port).
// Throws SocketRunTimeException if socket(), bind() or listen() fails.
//
// When bind()/listen() fails, the freshly created socket is closed before
// throwing. The state stays CLOSED in that case, so the destructor would
// never release it and the descriptor would leak.
void TCPSocketWrapper::listen(int port, int backlog)
{
SET_DETAILED_TRACE ("Port = " << port << ";  backlog = " << backlog);
  if (sockstate_ != CLOSED)
  {
    throw SocketLogicException (MSG_THROW("socket not in CLOSED state"));
  }

  if (port < 0 || port > 65535)
  {
    throw SocketLogicException (MSG_THROW("port out of range"));
  }

  sock_ = socket(AF_INET, SOCK_STREAM, 0);
  if (sock_ == INVALID_SOCKET)
  {
    throw SocketRunTimeException(MSG_THROW("socket failed"));
  }

  sockaddr_in local;
  memset(&local, 0, sizeof(local));
  local.sin_family = AF_INET;
  local.sin_port = htons((u_short)port);
  local.sin_addr.s_addr = htonl(INADDR_ANY);

  if (::bind (sock_, (sockaddr*)&local, sizeof(local)) == SOCKET_ERROR)
  {
    listenDiscardSocket(sock_);
    throw SocketRunTimeException(MSG_THROW("bind failed"));
  }

  if (::listen(sock_, backlog) == SOCKET_ERROR)
  {
    listenDiscardSocket(sock_);
    throw SocketRunTimeException(MSG_THROW("listen failed"));
  }

  // A listening socket has no peer: keep the stored address all zeros.
  memset(&sockaddress_, 0, sizeof(sockaddress_));
  sockstate_ = LISTENING;
}


// ----------------------------
// Blocks until a client connects to this LISTENING socket. Returns the new
// connection's descriptor and peer address, which can be wrapped with
// TCPSocketWrapper(const TCPAcceptedSocket&).
//
// Throws SocketLogicException if the socket is not LISTENING, and
// SocketRunTimeException if ::accept() fails. On POSIX, an accept() that is
// interrupted by a signal (EINTR) is retried rather than reported as a
// failure.
TCPSocketWrapper::TCPAcceptedSocket TCPSocketWrapper::accept()
{
SET_TRACE;
  if (sockstate_ != LISTENING)
  {
    throw SocketLogicException(MSG_THROW("socket not listening"));
  }

  sockaddr_in from;
  socket_type newsocket;
  bool retry;
  do
  {
    // accept() writes back through len, so reset both outputs on every try.
    socklen_t len = sizeof(from);
    memset(&from, 0, len);
    newsocket = ::accept(sock_, (sockaddr*)&from, &len);
#ifdef WIN32
    retry = false;
#else
    retry = (newsocket == INVALID_SOCKET && errno == EINTR);
#endif
  } while (retry);

  if (newsocket == INVALID_SOCKET)
  {
    throw SocketRunTimeException(MSG_THROW("accept failed"));
  }

  return TCPAcceptedSocket(newsocket, from);
}


// ----------------------------
// Returns a copy of the stored address: the peer after connect()/accept,
// all zeros after listen(). No state check is performed.
sockaddr_in TCPSocketWrapper::get_sockaddress () const
{
SET_TRACE;
  const sockaddr_in snapshot = sockaddress_;
  return snapshot;
}



// ----------------------------
void TCPSocketWrapper::connect(const char *address, int port)
{
SET_DETAILED_TRACE ("Address = " << address << ";  Port = " << port);

ostringstream oss_ip;
  oss_ip << "Server IP Address = "
         << address
         << ", Port No = "
         << port;

  if (sockstate_ != CLOSED)
  {
    ostringstream oss;
    oss << "socket not in CLOSED state"; 

    cout << "\t" << string (oss_ip.str().size(), '#') << endl;
    cout << "\t" << oss.str() << endl;
    cout << "\t" << oss_ip.str() << endl;
    cout << "\t" << string (oss_ip.str().size(), '#') << endl;

    throw SocketLogicException(MSG_THROW("socket not in CLOSED state"));
  }

  sock_ = socket(AF_INET, SOCK_STREAM, 0);
  if (sock_ == INVALID_SOCKET)
  {
    ostringstream oss;
    oss << "socket failed"; 

    cout << "\t" << string (oss_ip.str().size(), '#') << endl;
    cout << "\t" << oss.str() << endl;
    cout << "\t" << oss_ip.str() << endl;
    cout << "\t" << string (oss_ip.str().size(), '#') << endl;

    throw SocketRunTimeException(MSG_THROW("socket failed"));
  }

  hostent *hp;

  unsigned long addr = inet_addr(address);
  if (addr != INADDR_NONE)
  {
    hp = gethostbyaddr((const char*)&addr, 4, AF_INET);
  }
  else
  {
    hp = gethostbyname(address);
  }

  if (hp == NULL)
  {
    ostringstream oss;
    oss << "cannot resolve address"; 

    cout << "\t" << string (oss_ip.str().size(), '#') << endl;
    cout << "\t" << oss.str() << endl;
    cout << "\t" << oss_ip.str() << endl;
    cout << "\t" << string (oss_ip.str().size(), '#') << endl;

    throw SocketRunTimeException(MSG_THROW("cannot resolve address"));
  }

  if (hp->h_addrtype != AF_INET)
  {
    ostringstream oss;
    oss << "address resolved with TCP incompatible type"; 

    cout << "\t" << string (oss_ip.str().size(), '#') << endl;
    cout << "\t" << oss.str() << endl;
    cout << "\t" << oss_ip.str() << endl;
    cout << "\t" << string (oss_ip.str().size(), '#') << endl;

    throw SocketRunTimeException
	(MSG_THROW("address resolved with TCP incompatible type"));
  }

  memset(&sockaddress_, 0, sizeof(sockaddress_));
  memcpy(&(sockaddress_.sin_addr), hp->h_addr_list[0], hp->h_length);
  sockaddress_.sin_family = AF_INET;
  sockaddress_.sin_port = htons((u_short)port);

  if (::connect(sock_, (sockaddr*)&sockaddress_, sizeof(sockaddress_)) == SOCKET_ERROR)
  {
    ostringstream oss;
    oss <<"connect failed"; 

    cout << "\t" << string (oss_ip.str().size(), '#') << endl;
    cout << "\t" << oss.str() << endl;
    cout << "\t" << oss_ip.str() << endl;
    cout << "\t" << string (oss_ip.str().size(), '#') << endl;
    
    throw SocketRunTimeException(MSG_THROW("connect failed"));
  }

  sockstate_ = CONNECTED;
}

// ----------------------------
// Returns the peer's IPv4 address in dotted-decimal form.
// Throws SocketLogicException unless the socket is CONNECTED or ACCEPTED.
//
// NOTE: the text comes from inet_ntoa(), which returns a static buffer. That
// buffer is overwritten by the next inet_ntoa() call and is not thread-safe,
// so copy the string if it must be kept.
const char * TCPSocketWrapper::address () const
{
SET_TRACE;
  const bool hasPeer = (sockstate_ == CONNECTED) || (sockstate_ == ACCEPTED);
  if (!hasPeer)
  {
    throw SocketLogicException(MSG_THROW("socket not connected"));
  }

  const in_addr peerAddr = sockaddress_.sin_addr;
  return inet_ntoa(peerAddr);
}

// ----------------------------
// Returns the peer's port in host byte order.
// Throws SocketLogicException unless the socket is CONNECTED or ACCEPTED.
int TCPSocketWrapper::port () const
{
SET_TRACE;
  const bool hasPeer = (sockstate_ == CONNECTED) || (sockstate_ == ACCEPTED);
  if (!hasPeer)
  {
    throw SocketLogicException(MSG_THROW("socket not connected"));
  }

  const int hostOrderPort = ntohs(sockaddress_.sin_port);
  return hostOrderPort;
}

// ----------------------------
// Sends all `len` bytes of `buf`. It loops because send() may transmit only
// part of a request.
//
// Throws SocketLogicException unless the socket is CONNECTED or ACCEPTED,
// and SocketRunTimeException if send() fails.
//
// Each send() asks for at most INT_MAX bytes because the length is passed
// and returned as int. Casting a larger size_t to int would yield a negative
// or wrong length. On POSIX, a send() interrupted by a signal (EINTR) is
// retried.
// NOTE(review): on POSIX, sending on a connection the peer has closed can
// raise SIGPIPE, which terminates the process unless the application
// ignores/handles that signal.
void TCPSocketWrapper::write(const void *buf, size_t len)
{
#if (TRACE_LOG == 1)
  if (buf)
  {
    // buf is not necessarily NUL-terminated: bound the traced text by len.
    SET_DETAILED_TRACE ("buf_len = " << len << ", buf = <" << string ((const char*)buf, ((len > 0) ? (len - 1) : 0)) << ">");
  }
  else
  {
    SET_DETAILED_TRACE ("buf_len = " << len << ", buf (pointer) = NULL");
  }
#endif

  // ------------------

  if (sockstate_ != CONNECTED && sockstate_ != ACCEPTED)
  {
    throw SocketLogicException(MSG_THROW("socket not connected"));
  }

  const char *next = (const char*)buf;
  while (len != 0)
  {
    const int chunk = (len > (size_t)INT_MAX) ? INT_MAX : (int)len;
    const int written = send(sock_, next, chunk, 0);
    if (written == SOCKET_ERROR)
    {
#ifndef WIN32
      if (errno == EINTR)
      {
        continue;  // interrupted before anything was sent: try again
      }
#endif
      throw SocketRunTimeException(MSG_THROW("write failed"));
    }

    len -= (size_t)written;
    next += written;
  }
}

// ----------------------------
// Receives at most `len` bytes into `buf` with a single recv() and returns
// the number of bytes stored. A return of 0 means the peer performed an
// orderly shutdown.
//
// Throws SocketLogicException unless the socket is CONNECTED or ACCEPTED,
// and SocketRunTimeException if recv() fails.
//
// The request is clamped to INT_MAX because the length is passed as int.
// On POSIX, a recv() interrupted by a signal (EINTR) is retried.
size_t TCPSocketWrapper::read(void *buf, size_t len)
{
SET_TRACE;
  if (sockstate_ != CONNECTED && sockstate_ != ACCEPTED)
  {
    throw SocketLogicException(MSG_THROW("socket not connected"));
  }

SHOW1_TRACE ("BEFORE recv");

  const int request = (len > (size_t)INT_MAX) ? INT_MAX : (int)len;

  int readn;
  bool retry;
  do
  {
    readn = recv(sock_, (char*)buf, request, 0);
#ifdef WIN32
    retry = false;
#else
    retry = (readn == SOCKET_ERROR && errno == EINTR);
#endif
  } while (retry);

  if (readn == SOCKET_ERROR)
  {
    throw SocketRunTimeException(MSG_THROW("read failed"));
  }

  // recv() never stores more than requested. A completely filled buffer
  // (readn == len) and a zero-length request are both legitimate, so the
  // bound is inclusive.
  assert (readn <= request);

  // Only readn bytes were received and they are not NUL-terminated: bound the
  // traced text by readn instead of scanning for a terminator.
  SHOW2_TRACE ("AFTER  recv", 
	"buf_len = " + 
	to_string (len) + 
	"; read_n = " + 
	to_string (readn) + 
	", buf = <" + 
	string ((const char*)buf, ((readn > 0) ? (size_t)(readn - 1) : 0)) +
	">"
	);

  // ------------------

  return (size_t)readn;
}


// ----------------------------
// Returns the wrapper's current state (values used in this file: CLOSED,
// LISTENING, CONNECTED, ACCEPTED).
TCPSocketWrapper::sockstate_type TCPSocketWrapper::state () const
{
SET_TRACE;
  const sockstate_type current = this->sockstate_;
  return current;
}

void TCPSocketWrapper::close()
{
SET_TRACE;
  if (sockstate_ != CLOSED)
  {
#ifdef WIN32
    if (closesocket(sock_) == SOCKET_ERROR)
    {
      throw SocketRunTimeException(MSG_THROW("close failed"));
    }
#else
    if (::close(sock_))
    {
      throw SocketRunTimeException("MSG_THROW(close failed"));
    }
#endif
    sockstate_ = CLOSED;
  }
}


// ----------------------------
// ----------------------------
// Initializes the platform socket library and returns true on success.
// On Windows this starts Winsock (version 2.0 requested) via WSAStartup().
// Other platforms need no initialization and always succeed.
bool socketsInit()
{
SET_TRACE;

#ifdef WIN32
  WSADATA wsadata;
  const bool started = (WSAStartup(MAKEWORD(2, 0), &wsadata) == 0);
  return started;
#else
  return true;
#endif
}

// ----------------------------
// Releases the platform socket library: WSACleanup() on Windows, a no-op
// elsewhere. Intended for program shutdown, so the result is intentionally
// ignored.
void socketsEnd()
{
SET_TRACE;
#ifdef WIN32
  (void) WSACleanup();
#endif
}
