join 1.0
lightweight network framework library
Loading...
Searching...
No Matches
socket.hpp
Go to the documentation of this file.
1
25#ifndef __JOIN_SOCKET_HPP__
26#define __JOIN_SOCKET_HPP__
27
28// libjoin.
29#include <join/protocol.hpp>
30#include <join/endpoint.hpp>
31#include <join/openssl.hpp>
32#include <join/reactor.hpp>
33#include <join/utils.hpp>
34#include <join/error.hpp>
35
36// Libraries.
37#include <openssl/err.h>
38
39// C++.
40#include <type_traits>
41#include <iostream>
42
43// C.
44#include <netinet/tcp.h>
45#include <linux/icmp.h>
46#include <sys/ioctl.h>
47#include <sys/stat.h>
48#include <fnmatch.h>
49#include <cassert>
50#include <fcntl.h>
51#include <poll.h>
52
53namespace join
54{
// NOTE(review): source line 59, the class-head line (presumably `class BasicSocket ...`),
// is absent from this listing; handle () below is declared `override`, so that line
// presumably also names a base class — confirm against the full source.
58 template <class Protocol>
60 {
61 public:
// owning pointer alias for this socket type.
62 using Ptr = std::unique_ptr <BasicSocket <Protocol>>;
// endpoint type supplied by the protocol class.
63 using Endpoint = typename Protocol::Endpoint;
64
// blocking mode of the socket, consulted by open () and setMode ().
// NOTE(review): the enumerators (source lines 70-71) are absent from this listing;
// only Mode::NonBlocking is referenced in the visible code.
68 enum Mode
69 {
72 };
73
// NOTE(review): source lines 74-96 are absent from this listing; State (used throughout
// the class) is not declared in the visible lines and presumably lives there.
97
109
// empty constructor body; its signature (source lines 110-114) is absent from this listing.
115 {
116 }
117
// constructor storing the requested blocking mode; its signature line (source 122)
// is absent from this listing.
123 : _mode (mode)
124 {
125 }
126
// a socket owns a descriptor, so copying is disabled.
131 BasicSocket (const BasicSocket& other) = delete;
132
138 BasicSocket& operator= (const BasicSocket& other) = delete;
139
// move constructor (signature line, source 144, absent from this listing): takes over
// other's state, mode, descriptor and protocol, then leaves other Closed, NonBlocking,
// with handle -1 and a default-constructed protocol.
145 : _state (other._state),
146 _mode (other._mode),
147 _handle (other._handle),
148 _protocol (other._protocol)
149 {
150 other._state = State::Closed;
151 other._mode = Mode::NonBlocking;
152 other._handle = -1;
153 other._protocol = Protocol ();
154 }
155
// move assignment (signature line, source 161, absent from this listing): closes this
// socket, takes over other's members, then resets other like the move constructor.
// NOTE(review): there is no self-assignment guard; on self-move close () releases the
// descriptor and the copied _handle is -1 afterwards — confirm self-move cannot happen.
162 {
163 this->close ();
164
165 this->_state = other._state;
166 this->_mode = other._mode;
167 this->_handle = other._handle;
168 this->_protocol = other._protocol;
169
170 other._state = State::Closed;
171 other._mode = Mode::NonBlocking;
172 other._handle = -1;
173 other._protocol = Protocol ();
174
175 return *this;
176 }
177
181 virtual ~BasicSocket ()
182 {
183 if (this->_handle != -1)
184 {
185 ::close (this->_handle);
186 }
187 }
188
194 virtual int open (const Protocol& protocol = Protocol ()) noexcept
195 {
196 if (this->_state != State::Closed)
197 {
198 lastError = make_error_code (Errc::InUse);
199 return -1;
200 }
201
202 if (this->_mode == Mode::NonBlocking)
203 this->_handle = ::socket (protocol.family (), protocol.type () | SOCK_NONBLOCK, protocol.protocol ());
204 else
205 this->_handle = ::socket (protocol.family (), protocol.type (), protocol.protocol ());
206
207 if (this->_handle == -1)
208 {
209 lastError = std::make_error_code (static_cast <std::errc> (errno));
210 this->close ();
211 return -1;
212 }
213
215 this->_protocol = protocol;
216
217 return 0;
218 }
219
223 virtual void close () noexcept
224 {
225 if (this->_state != State::Closed)
226 {
227 ::close (this->_handle);
228 this->_state = State::Closed;
229 this->_handle = -1;
230 }
231 }
232
238 virtual int bind (const Endpoint& endpoint) noexcept
239 {
240 if ((this->_state == State::Closed) && (this->open (endpoint.protocol ()) == -1))
241 {
242 return -1;
243 }
244
245 if (endpoint.protocol ().family () == AF_PACKET)
246 {
247 if (reinterpret_cast <const struct sockaddr_ll*> (endpoint.addr ())->sll_ifindex == 0)
248 {
249 lastError = std::make_error_code (std::errc::no_such_device);
250 return -1;
251 }
252 }
253 else if ((endpoint.protocol ().family () == AF_INET6) || (endpoint.protocol ().family () == AF_INET))
254 {
255 this->setOption (Option::ReuseAddr, 1);
256 }
257 else if (endpoint.protocol ().family () == AF_UNIX)
258 {
259 ::unlink (endpoint.device ().c_str ());
260 }
261
262 if (::bind (this->_handle, endpoint.addr (), endpoint.length ()) == -1)
263 {
264 lastError = std::make_error_code (static_cast <std::errc> (errno));
265 return -1;
266 }
267
268 return 0;
269 }
270
275 virtual int canRead () const noexcept
276 {
277 int available = 0;
278
279 // check if data can be read in the socket internal buffer.
280 if (::ioctl (this->_handle, FIONREAD, &available) == -1)
281 {
282 lastError = std::make_error_code (static_cast <std::errc> (errno));
283 return -1;
284 }
285
286 return available;
287 }
288
294 virtual bool waitReadyRead (int timeout = 0) const noexcept
295 {
296 return (this->wait (true, false, timeout) == 0);
297 }
298
305 virtual int read (char *data, unsigned long maxSize) noexcept
306 {
307 struct iovec iov;
308 iov.iov_base = data;
309 iov.iov_len = maxSize;
310
311 struct msghdr message;
312 message.msg_name = nullptr;
313 message.msg_namelen = 0;
314 message.msg_iov = &iov;
315 message.msg_iovlen = 1;
316 message.msg_control = nullptr;
317 message.msg_controllen = 0;
318
319 int size = ::recvmsg (this->_handle, &message, 0);
320 if (size < 1)
321 {
322 if (size == -1)
323 {
324 lastError = std::make_error_code (static_cast <std::errc> (errno));
325 }
326 else
327 {
329 }
330 return -1;
331 }
332
333 return size;
334 }
335
341 virtual bool waitReadyWrite (int timeout = 0) const noexcept
342 {
343 return (this->wait (false, true, timeout) == 0);
344 }
345
352 virtual int write (const char *data, unsigned long maxSize) noexcept
353 {
354 struct iovec iov;
355 iov.iov_base = const_cast <char *> (data);
356 iov.iov_len = maxSize;
357
358 struct msghdr message;
359 message.msg_name = nullptr;
360 message.msg_namelen = 0;
361 message.msg_iov = &iov;
362 message.msg_iovlen = 1;
363 message.msg_control = nullptr;
364 message.msg_controllen = 0;
365
366 int result = ::sendmsg (this->_handle, &message, 0);
367 if (result == -1)
368 {
369 lastError = std::make_error_code (static_cast <std::errc> (errno));
370 return -1;
371 }
372
373 return result;
374 }
375
380 void setMode (Mode mode) noexcept
381 {
382 this->_mode = mode;
383
384 if (this->_state != State::Closed)
385 {
386 int flags = ::fcntl (this->_handle, F_GETFL, 0);
387
388 if (this->_mode == Mode::NonBlocking)
389 {
390 flags = flags | O_NONBLOCK;
391 }
392 else
393 {
394 flags = flags & ~O_NONBLOCK;
395 }
396
397 ::fcntl (this->_handle, F_SETFL, flags);
398 }
399 }
400
/**
 * @brief set a socket level option by mapping Option to a (level, name) pair for ::setsockopt ().
 * @param option option to set.
 * @param value option value.
 * @return 0 on success, -1 on failure (lastError is set when ::setsockopt () fails).
 */
// NOTE(review): in this listing source lines 413, 418, 423, 428, 433, 438 and 443 are
// absent (presumably the `case Option::...:` labels of the SOL_SOCKET branches below).
// As written, every statement placed before `case Option::AuxData:` precedes the first
// case label and is unreachable, so only AuxData and default are dispatched. Source
// line 454 is also absent, so as written default returns -1 without updating lastError.
// Confirm against the full source.
407 virtual int setOption (Option option, int value) noexcept
408 {
409 int optlevel, optname;
410
411 switch (option)
412 {
// SO_KEEPALIVE branch (case label missing from this listing).
414 optlevel = SOL_SOCKET;
415 optname = SO_KEEPALIVE;
416 break;
417
// SO_SNDBUF branch (case label missing from this listing).
419 optlevel = SOL_SOCKET;
420 optname = SO_SNDBUF;
421 break;
422
// SO_RCVBUF branch (case label missing from this listing).
424 optlevel = SOL_SOCKET;
425 optname = SO_RCVBUF;
426 break;
427
// SO_TIMESTAMP branch (case label missing from this listing).
429 optlevel = SOL_SOCKET;
430 optname = SO_TIMESTAMP;
431 break;
432
// SO_REUSEADDR branch (case label missing; bind () requests Option::ReuseAddr).
434 optlevel = SOL_SOCKET;
435 optname = SO_REUSEADDR;
436 break;
437
// SO_REUSEPORT branch (case label missing from this listing).
439 optlevel = SOL_SOCKET;
440 optname = SO_REUSEPORT;
441 break;
442
// SO_BROADCAST branch (case label missing from this listing).
444 optlevel = SOL_SOCKET;
445 optname = SO_BROADCAST;
446 break;
447
// packet socket auxiliary data.
448 case Option::AuxData:
449 optlevel = SOL_PACKET;
450 optname = PACKET_AUXDATA;
451 break;
452
// unsupported option.
453 default:
455 return -1;
456 }
457
458 int result = ::setsockopt (this->_handle, optlevel, optname, &value, sizeof (value));
459 if (result == -1)
460 {
461 lastError = std::make_error_code (static_cast <std::errc> (errno));
462 return -1;
463 }
464
465 return 0;
466 }
467
// returns the address the socket is bound to, obtained with ::getsockname ().
// NOTE(review): the signature line (source 472) is absent from this listing.
// On failure a default-constructed Endpoint is returned and lastError is not updated.
473 {
474 struct sockaddr_storage sa;
475 socklen_t sa_len = sizeof (struct sockaddr_storage);
476
477 if (::getsockname (this->_handle, reinterpret_cast <struct sockaddr*> (&sa), &sa_len) == -1)
478 {
479 return {};
480 }
481
482 return Endpoint (reinterpret_cast <struct sockaddr*> (&sa), sa_len);
483 }
484
489 bool opened () const noexcept
490 {
491 return (this->_state != State::Closed);
492 }
493
498 virtual bool encrypted () const noexcept
499 {
500 return false;
501 }
502
507 int family () const noexcept
508 {
509 return this->_protocol.family ();
510 }
511
516 int type () const noexcept
517 {
518 return this->_protocol.type ();
519 }
520
525 int protocol () const noexcept
526 {
527 return this->_protocol.protocol ();
528 }
529
534 int handle () const noexcept override
535 {
536 return this->_handle;
537 }
538
/**
 * @brief compute the internet checksum (16 bit one's complement sum, RFC 1071) of a buffer.
 * @param data buffer interpreted as 16 bit words in host byte order.
 * @param len buffer length in bytes; a trailing odd byte is zero padded.
 * @param current initial value added to the sum.
 * @return one's complement of the folded 16 bit sum.
 */
static uint16_t checksum (const uint16_t* data, size_t len, uint16_t current = 0)
{
    // 64 bit accumulator: a 32 bit one silently wrapped for buffers larger than ~128 KiB,
    // losing carries and producing a wrong checksum.
    uint64_t sum = current;

    while (len > 1)
    {
        sum += *data++;
        len -= 2;
    }

    if (len == 1)
    {
        // pad the trailing byte with zero where it sits in network byte order.
        #if __BYTE_ORDER == __LITTLE_ENDIAN
        sum += *reinterpret_cast <const uint8_t *> (data);
        #else
        sum += static_cast <uint64_t> (*reinterpret_cast <const uint8_t *> (data)) << 8;
        #endif
    }

    // fold the carries back into the low 16 bits until none remain.
    while (sum >> 16)
    {
        sum = (sum & 0xffff) + (sum >> 16);
    }

    return static_cast <uint16_t> (~sum);
}
570
571 protected:
579 int wait (bool wantRead, bool wantWrite, int timeout) const noexcept
580 {
581 struct pollfd handle { .fd = this->_handle, .events = 0, .revents = 0 };
582
583 if (wantRead)
584 {
585 handle.events |= POLLIN;
586 }
587
588 if (wantWrite)
589 {
590 handle.events |= POLLOUT;
591 }
592
593 int nset = (handle.fd > -1) ? ::poll (&handle, 1, timeout) : -1;
594 if (nset != 1)
595 {
596 if (nset == -1)
597 {
598 if (handle.fd == -1)
599 {
600 errno = EBADF;
601 }
602 lastError = std::make_error_code (static_cast <std::errc> (errno));
603 }
604 else
605 {
606 lastError = make_error_code (Errc::TimedOut);
607 }
608
609 return -1;
610 }
611
612 return 0;
613 }
614
// NOTE(review): the declarations of _state and _mode (source lines 618-621, both used
// throughout the class) are absent from this listing.
617
620
// socket descriptor, -1 when none is held.
622 int _handle = -1;
623
// protocol the socket was last opened with (assigned by open ()).
625 Protocol _protocol;
626 };
627
634 template <class Protocol>
635 constexpr bool operator< (const BasicSocket <Protocol>& a, const BasicSocket <Protocol>& b) noexcept
636 {
637 return a.handle () < b.handle ();
638 }
639
/**
 * @brief datagram socket: adds connection state handling, remote endpoint tracking
 * and IP level options on top of BasicSocket.
 */
// NOTE(review): source lines 648-650 are absent from this listing; since Mode, State and
// Option come from the dependent base, they presumably held the matching using-declarations.
643 template <class Protocol>
644 class BasicDatagramSocket : public BasicSocket <Protocol>
645 {
646 public:
// owning pointer alias for this socket type.
647 using Ptr = std::unique_ptr <BasicDatagramSocket <Protocol>>;
// endpoint type supplied by the protocol class.
651 using Endpoint = typename Protocol::Endpoint;
652
660
665 BasicDatagramSocket (Mode mode, int ttl = 60)
666 : BasicSocket <Protocol> (mode),
667 _ttl (ttl)
668 {
669 }
670
676
683
// move constructor (signature line, source 688, absent from this listing): moves the base
// part and the remote endpoint, copies the ttl, then resets other's ttl to 60.
689 : BasicSocket <Protocol> (std::move (other)),
690 _remote (std::move (other._remote)),
691 _ttl (other._ttl)
692 {
693 other._ttl = 60;
694 }
695
// move assignment (signature line, source 701, absent from this listing): same transfer
// as the move constructor, on top of the base move assignment.
702 {
703 BasicSocket <Protocol>::operator= (std::move (other));
704
705 _remote = std::move (other._remote);
706 _ttl = other._ttl;
707
708 other._ttl = 60;
709
710 return *this;
711 }
712
// the base destructor releases the descriptor.
716 virtual ~BasicDatagramSocket () = default;
717
723 virtual int open (const Protocol& protocol = Protocol ()) noexcept override
724 {
726 if (result == -1)
727 {
728 return -1;
729 }
730
731 int off = 0;
732
733 if ((protocol.protocol () == IPPROTO_UDP) || (protocol.protocol () == IPPROTO_TCP))
734 {
735 if ((protocol.family () == AF_INET6) && (::setsockopt (this->_handle, IPPROTO_IPV6, IPV6_V6ONLY, &off, sizeof (off)) == -1))
736 {
737 lastError = std::make_error_code (static_cast <std::errc> (errno));
738 this->close ();
739 return -1;
740 }
741 }
742
743 if ((protocol.protocol () == IPPROTO_ICMPV6) || (protocol.protocol () == IPPROTO_ICMP))
744 {
745 if ((protocol.family () == AF_INET) && (::setsockopt (this->_handle, IPPROTO_IP, IP_HDRINCL, &off, sizeof (off)) == -1))
746 {
747 lastError = std::make_error_code (static_cast <std::errc> (errno));
748 this->close ();
749 return -1;
750 }
751
752 this->setOption (Option::MulticastTtl, this->_ttl);
753 this->setOption (Option::Ttl, this->_ttl);
754 }
755
756 return 0;
757 }
758
764 virtual int bindToDevice (const std::string& device) noexcept
765 {
766 if (this->_state == State::Closed)
767 {
769 return -1;
770 }
771
772 if (this->_state == State::Connected)
773 {
774 lastError = make_error_code (Errc::InUse);
775 return -1;
776 }
777
778 if ((this->_protocol.family () == AF_INET6) || (this->_protocol.family () == AF_INET))
779 {
780 this->setOption (Option::ReuseAddr, 1);
781 }
782
783 int result = setsockopt (this->_handle, SOL_SOCKET, SO_BINDTODEVICE, device.c_str (), device.size ());
784 if (result == -1)
785 {
786 lastError = std::make_error_code (static_cast <std::errc> (errno));
787 return -1;
788 }
789
790 return 0;
791 }
792
798 virtual int connect (const Endpoint& endpoint)
799 {
800 if ((this->_state != State::Closed) && (this->_state != State::Disconnected))
801 {
802 lastError = make_error_code (Errc::InUse);
803 return -1;
804 }
805
806 if ((this->_state == State::Closed) && (this->open (endpoint.protocol ()) == -1))
807 {
808 return -1;
809 }
810
811 int result = ::connect (this->_handle, endpoint.addr (), endpoint.length ());
812
813 this->_state = State::Connecting;
814 this->_remote = endpoint;
815
816 if (result == -1)
817 {
818 lastError = std::make_error_code (static_cast <std::errc> (errno));
819 if (lastError != std::errc::operation_in_progress)
820 {
821 this->close ();
822 }
823 return -1;
824 }
825
826 this->_state = State::Connected;
827
828 return 0;
829 }
830
835 virtual int disconnect ()
836 {
837 if (this->_state == State::Connected)
838 {
839 struct sockaddr_storage nullAddr;
840 ::memset (&nullAddr, 0, sizeof (nullAddr));
841
842 nullAddr.ss_family = AF_UNSPEC;
843
844 int result = ::connect (this->_handle, reinterpret_cast <struct sockaddr*> (&nullAddr), sizeof (struct sockaddr_storage));
845 if (result == -1)
846 {
847 if (errno != EAFNOSUPPORT)
848 {
849 lastError = std::make_error_code (static_cast <std::errc> (errno));
850 return -1;
851 }
852 }
853
854 this->_state = State::Disconnected;
855 this->_remote = {};
856 }
857
858 return 0;
859 }
860
864 virtual void close () noexcept override
865 {
867 this->_remote = {};
868 }
869
876 virtual int read (char *data, unsigned long maxSize) noexcept override
877 {
878 return BasicSocket <Protocol>::read (data, maxSize);
879 }
880
888 virtual int readFrom (char* data, unsigned long maxSize, Endpoint* endpoint = nullptr) noexcept
889 {
890 struct sockaddr_storage sa;
891 socklen_t sa_len = sizeof (struct sockaddr_storage);
892
893 int size = ::recvfrom (this->_handle, data, maxSize, 0, reinterpret_cast <struct sockaddr*> (&sa), &sa_len);
894 if (size < 1)
895 {
896 if (size == -1)
897 {
898 lastError = std::make_error_code (static_cast <std::errc> (errno));
899 }
900 else
901 {
903 this->_state = State::Disconnected;
904 }
905
906 return -1;
907 }
908
909 if (endpoint != nullptr)
910 {
911 *endpoint = Endpoint (reinterpret_cast <struct sockaddr*> (&sa), sa_len);
912 }
913
914 return size;
915 }
916
923 virtual int write (const char *data, unsigned long maxSize) noexcept override
924 {
925 return BasicSocket <Protocol>::write (data, maxSize);
926 }
927
935 virtual int writeTo (const char* data, unsigned long maxSize, const Endpoint& endpoint) noexcept
936 {
937 if ((this->_state == State::Closed) && (this->open (endpoint.protocol ()) == -1))
938 {
939 return -1;
940 }
941
942 int result = ::sendto (this->_handle, data, maxSize, 0, endpoint.addr (), endpoint.length ());
943 if (result < 0)
944 {
945 lastError = std::make_error_code (static_cast <std::errc> (errno));
946 return -1;
947 }
948
949 return result;
950 }
951
958 virtual int setOption (Option option, int value) noexcept override
959 {
960 if (this->_state == State::Closed)
961 {
963 return -1;
964 }
965
966 int optlevel, optname;
967
968 switch (option)
969 {
970 case Option::Ttl:
971 if (this->family () == AF_INET6)
972 {
973 optlevel = IPPROTO_IPV6;
974 optname = IPV6_UNICAST_HOPS;
975 }
976 else
977 {
978 optlevel = IPPROTO_IP;
979 optname = IP_TTL;
980 }
981 break;
982
983 case Option::MulticastLoop:
984 if (this->family () == AF_INET6)
985 {
986 optlevel = IPPROTO_IPV6;
987 optname = IPV6_MULTICAST_LOOP;
988 }
989 else
990 {
991 optlevel = IPPROTO_IP;
992 optname = IP_MULTICAST_LOOP;
993 }
994 break;
995
996 case Option::MulticastTtl:
997 if (this->family () == AF_INET6)
998 {
999 optlevel = IPPROTO_IPV6;
1000 optname = IPV6_MULTICAST_HOPS;
1001 }
1002 else
1003 {
1004 optlevel = IPPROTO_IP;
1005 optname = IP_MULTICAST_TTL;
1006 }
1007 break;
1008
1009 case Option::PathMtuDiscover:
1010 if (this->family () == AF_INET6)
1011 {
1012 optlevel = IPPROTO_IPV6;
1013 optname = IPV6_MTU_DISCOVER;
1014 }
1015 else
1016 {
1017 optlevel = IPPROTO_IP;
1018 optname = IP_MTU_DISCOVER;
1019 }
1020 break;
1021
1022 case Option::RcvError:
1023 if (this->family () == AF_INET6)
1024 {
1025 optlevel = IPPROTO_IPV6;
1026 optname = IPV6_RECVERR;
1027 }
1028 else
1029 {
1030 optlevel = IPPROTO_IP;
1031 optname = IP_RECVERR;
1032 }
1033 break;
1034
1035 default:
1036 return BasicSocket<Protocol>::setOption (option, value);
1037 }
1038
1039 int result = ::setsockopt (this->_handle, optlevel, optname, &value, sizeof (value));
1040 if (result == -1)
1041 {
1042 lastError = std::make_error_code (static_cast <std::errc> (errno));
1043 return -1;
1044 }
1045
1046 return 0;
1047 }
1048
1053 const Endpoint& remoteEndpoint () const
1054 {
1055 return this->_remote;
1056 }
1057
1062 virtual bool connected () noexcept
1063 {
1064 return (this->_state == State::Connected);
1065 }
1066
1071 int mtu () const
1072 {
1073 if (this->_state == State::Closed)
1074 {
1076 return -1;
1077 }
1078
1079 int result = -1, value = -1;
1080 socklen_t valueLen = sizeof (value);
1081
1082 if (this->_protocol.family () == AF_INET6)
1083 {
1084 result = ::getsockopt (this->_handle, IPPROTO_IPV6, IPV6_MTU, &value, &valueLen);
1085 }
1086 else if (this->_protocol.family () == AF_INET)
1087 {
1088 result = ::getsockopt (this->_handle, IPPROTO_IP, IP_MTU, &value, &valueLen);
1089 }
1090 else
1091 {
1093 return -1;
1094 }
1095
1096 if (result == -1)
1097 {
1098 lastError = std::make_error_code (static_cast <std::errc> (errno));
1099 return -1;
1100 }
1101
1102 return value;
1103 }
1104
1109 int ttl () const
1110 {
1111 return this->_ttl;
1112 }
1113
1114 protected:
// NOTE(review): source line 1116 (presumably the `_remote` member used by connect (),
// disconnect () and remoteEndpoint ()) is absent from this listing.
1117
// time to live that open () applies to ICMP sockets (Ttl and MulticastTtl); defaults to 60.
1119 int _ttl = 60;
1120 };
1121
1128 template <class Protocol>
1129 constexpr bool operator< (const BasicDatagramSocket <Protocol>& a, const BasicDatagramSocket <Protocol>& b) noexcept
1130 {
1131 return a.handle () < b.handle ();
1132 }
1133
/**
 * @brief stream socket: adds connection waiting, lingering close and exact-size
 * read/write helpers on top of BasicDatagramSocket.
 */
// NOTE(review): source lines 1142-1144 are absent from this listing (presumably
// using-declarations for dependent base types such as Mode — confirm).
1137 template <class Protocol>
1138 class BasicStreamSocket : public BasicDatagramSocket <Protocol>
1139 {
1140 public:
// owning pointer alias for this socket type.
1141 using Ptr = std::unique_ptr <BasicStreamSocket <Protocol>>;
// endpoint type supplied by the protocol class.
1145 using Endpoint = typename Protocol::Endpoint;
1146
1154
// constructor forwarding `mode` to the datagram base; its signature line (source 1159)
// is absent from this listing.
1160 : BasicDatagramSocket <Protocol> (mode)
1161 {
1162 }
1163
// copying is disabled.
1168 BasicStreamSocket (const BasicStreamSocket &other) = delete;
1169
1176
// move constructor (signature line, source 1181, absent): delegates to the datagram base.
1182 : BasicDatagramSocket <Protocol> (std::move (other))
1183 {
1184 }
1185
// move assignment (signature line, source 1191, absent): delegates to the datagram base.
1192 {
1193 BasicDatagramSocket <Protocol>::operator= (std::move (other));
1194
1195 return *this;
1196 }
1197
1201 virtual ~BasicStreamSocket () = default;
1202
1208 virtual bool waitConnected (int timeout = 0)
1209 {
1210 if (this->_state != State::Connected)
1211 {
1212 if (this->_state != State::Connecting)
1213 {
1215 return false;
1216 }
1217
1218 if (!this->waitReadyWrite (timeout))
1219 {
1220 return false;
1221 }
1222
1223 return connected ();
1224 }
1225
1226 return true;
1227 }
1228
1233 virtual int disconnect () override
1234 {
1235 if (this->_state == State::Connected)
1236 {
1237 ::shutdown (this->_handle, SHUT_WR);
1238 this->_state = State::Disconnecting;
1239 }
1240
1241 if (this->_state == State::Disconnecting)
1242 {
1243 char buffer[4096];
1244 // closing before reading can make the client
1245 // not see all of our output.
1246 // we have to do a "lingering close"
1247 for (;;)
1248 {
1249 int result = this->read (buffer, sizeof (buffer));
1250 if (result <= 0)
1251 {
1252 if ((result == -1) && (lastError == Errc::TemporaryError))
1253 {
1254 return -1;
1255 }
1256
1257 break;
1258 }
1259 }
1260
1261 ::shutdown (this->_handle, SHUT_RD);
1262 this->_state = State::Disconnected;
1263 }
1264
1265 this->close ();
1266
1267 return 0;
1268 }
1269
1275 virtual bool waitDisconnected (int timeout = 0)
1276 {
1277 if ((this->_state != State::Disconnected) && (this->_state != State::Closed))
1278 {
1279 if (this->_state != State::Disconnecting)
1280 {
1282 return false;
1283 }
1284
1285 auto start = std::chrono::steady_clock::now ();
1286 int elapsed = 0;
1287
1288 while ((lastError == Errc::TemporaryError) && (elapsed <= timeout))
1289 {
1290 if (!this->waitReadyRead (timeout - elapsed))
1291 {
1292 return false;
1293 }
1294
1295 if (this->disconnect () == 0)
1296 {
1297 return true;
1298 }
1299
1300 if (timeout)
1301 {
1302 elapsed = std::chrono::duration_cast <std::chrono::milliseconds> (std::chrono::steady_clock::now () - start).count ();
1303 }
1304 }
1305
1306 return false;
1307 }
1308
1309 return true;
1310 }
1311
1319 int readExactly (char *data, unsigned long size, int timeout = 0)
1320 {
1321 unsigned long numRead = 0;
1322
1323 while (numRead < size)
1324 {
1325 int result = this->read (data + numRead, size - numRead);
1326 if (result == -1)
1327 {
1328 if (lastError == Errc::TemporaryError)
1329 {
1330 if (this->waitReadyRead (timeout))
1331 continue;
1332 }
1333
1334 return -1;
1335 }
1336
1337 numRead += result;
1338 }
1339
1340 return 0;
1341 }
1342
1350 int readExactly (std::string& data, unsigned long size, int timeout = 0)
1351 {
1352 data.resize (size);
1353 return readExactly (&data[0], size, timeout);
1354 }
1355
1363 int writeExactly (const char *data, unsigned long size, int timeout = 0)
1364 {
1365 unsigned long numWrite = 0;
1366
1367 while (numWrite < size)
1368 {
1369 int result = this->write (data + numWrite, size - numWrite);
1370 if (result == -1)
1371 {
1372 if (lastError == Errc::TemporaryError)
1373 {
1374 if (this->waitReadyWrite (timeout))
1375 continue;
1376 }
1377
1378 return -1;
1379 }
1380
1381 numWrite += result;
1382 }
1383
1384 return 0;
1385 }
1386
1393 virtual int setOption (Option option, int value) noexcept override
1394 {
1395 if (this->_state == State::Closed)
1396 {
1398 return -1;
1399 }
1400
1401 int optlevel, optname;
1402
1403 switch (option)
1404 {
1405 case Option::NoDelay:
1406 optlevel = IPPROTO_TCP;
1407 optname = TCP_NODELAY;
1408 break;
1409
1410 case Option::KeepIdle:
1411 optlevel = IPPROTO_TCP;
1412 optname = TCP_KEEPIDLE;
1413 break;
1414
1415 case Option::KeepIntvl:
1416 optlevel = IPPROTO_TCP;
1417 optname = TCP_KEEPINTVL;
1418 break;
1419
1420 case Option::KeepCount:
1421 optlevel = IPPROTO_TCP;
1422 optname = TCP_KEEPCNT;
1423 break;
1424
1425 default:
1426 return BasicDatagramSocket <Protocol>::setOption (option, value);
1427 }
1428
1429 int result = ::setsockopt (this->_handle, optlevel, optname, &value, sizeof (value));
1430 if (result == -1)
1431 {
1432 lastError = std::make_error_code (static_cast <std::errc> (errno));
1433 return -1;
1434 }
1435
1436 return 0;
1437 }
1438
1443 virtual bool connecting () const noexcept
1444 {
1445 return (this->_state == State::Connecting);
1446 }
1447
1452 virtual bool connected () noexcept override
1453 {
1454 if (this->_state == State::Connected)
1455 {
1456 return true;
1457 }
1458 else if (this->_state != State::Connecting)
1459 {
1460 return false;
1461 }
1462
1463 int optval;
1464 socklen_t optlen = sizeof (optval);
1465
1466 int result = ::getsockopt (this->_handle, SOL_SOCKET, SO_ERROR, &optval, &optlen);
1467 if ((result == -1) || (optval != 0))
1468 {
1469 return false;
1470 }
1471
1472 this->_state = State::Connected;
1473
1474 return true;
1475 }
1476
// grants BasicStreamAcceptor <Protocol> access to this class's protected and private members.
1478 friend class BasicStreamAcceptor <Protocol>;
1479 };
1480
1487 template <class Protocol>
1488 constexpr bool operator< (const BasicStreamSocket <Protocol>& a, const BasicStreamSocket <Protocol>& b) noexcept
1489 {
1490 return a.handle () < b.handle ();
1491 }
1492
/**
 * @brief TLS specific error codes.
 * NOTE(review): the enumerators (source lines 1498-1499) are absent from this listing.
 */
1496 enum class TlsErrc
1497 {
1500 };
1501
/**
 * @brief error category for TlsErrc codes.
 */
class TlsCategory : public std::error_category
{
public:
    /**
     * @brief get the category name.
     * @return category name.
     */
    virtual const char* name () const noexcept override;

    /**
     * @brief translate an error code into a human readable message.
     * @param code error code.
     * @return human readable message.
     */
    virtual std::string message (int code) const override;
};
1521
1526 const std::error_category& getTlsCategory ();
1527
1533 std::error_code make_error_code (TlsErrc code);
1534
1540 std::error_condition make_error_condition (TlsErrc code);
1541
/**
 * @brief stream socket with an optional OpenSSL TLS layer.
 */
// NOTE(review): source lines 1550-1552 are absent from this listing (presumably
// using-declarations for dependent base types — confirm).
1545 template <class Protocol>
1546 class BasicTlsSocket : public BasicStreamSocket <Protocol>
1547 {
1548 public:
// owning pointer alias for this socket type.
1549 using Ptr = std::unique_ptr <BasicTlsSocket <Protocol>>;
// endpoint type supplied by the protocol class.
1553 using Endpoint = typename Protocol::Endpoint;
1554
// empty constructor body; its signature and any initializer list (source lines 1555-1559)
// are absent from this listing.
1560 {
1561 }
1562
// constructor taking `mode` (signature line, source 1567, absent) that creates its own
// client context (TLS_client_method) and configures options, modes, cache, verification
// and cipher suites on it.
// NOTE(review): unlike the constructor taking an explicit context, this one does not check
// the SSL_CTX_new () result for nullptr before configuring it.
1568 : BasicStreamSocket <Protocol> (mode),
1569 _tlsContext (SSL_CTX_new (TLS_client_method ()))
1570 {
1571 // enable the OpenSSL bug workaround options.
1572 SSL_CTX_set_options (this->_tlsContext.get (), SSL_OP_ALL);
1573
1574 // disallow compression.
1575 SSL_CTX_set_options (this->_tlsContext.get (), SSL_OP_NO_COMPRESSION);
1576
1577 // disallow usage of SSLv2, SSLv3, TLSv1 and TLSv1.1 which are considered insecure.
1578 SSL_CTX_set_options (this->_tlsContext.get (), SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1);
1579
1580 // setup write mode.
1581 SSL_CTX_set_mode (this->_tlsContext.get (), SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
1582
1583 // automatically renegotiates.
1584 SSL_CTX_set_mode (this->_tlsContext.get (), SSL_MODE_AUTO_RETRY);
1585
1586 // set session cache mode to client by default.
1587 SSL_CTX_set_session_cache_mode (this->_tlsContext.get (), SSL_SESS_CACHE_CLIENT);
1588
1589 // no verification by default.
1590 SSL_CTX_set_verify (this->_tlsContext.get (), SSL_VERIFY_NONE, nullptr);
1591
1592 // set default TLSv1.2 and below cipher suites.
1593 SSL_CTX_set_cipher_list (this->_tlsContext.get (), join::defaultCipher.c_str ());
1594
1595 // set default TLSv1.3 cipher suites.
1596 SSL_CTX_set_ciphersuites (this->_tlsContext.get (), join::defaultCipher_1_3.c_str ());
1597 }
1598
// constructor taking only a context (signature line, source 1603, absent): delegates with
// Mode::NonBlocking.
1604 : BasicTlsSocket (Mode::NonBlocking, std::move (tlsContext))
1605 {
1606 }
1607
// constructor taking a mode and a caller supplied context (signature line, source 1613,
// absent); throws std::invalid_argument when the context is null.
1614 : BasicStreamSocket <Protocol> (mode),
1615 _tlsContext (std::move (tlsContext))
1616 {
1617 if (this->_tlsContext == nullptr)
1618 {
1619 throw std::invalid_argument ("OpenSSL context is invalid");
1620 }
1621 }
1622
// copying is disabled.
1627 BasicTlsSocket (const BasicTlsSocket &other) = delete;
1628
1635
// move constructor (signature line, source 1640, absent): takes over context, SSL handle and
// TLS state, re-points the SSL app data at the new object, and marks other NonEncrypted.
1641 : BasicStreamSocket <Protocol> (std::move (other)),
1642 _tlsContext (std::move (other._tlsContext)),
1643 _tlsHandle (std::move (other._tlsHandle)),
1644 _tlsState (other._tlsState)
1645 {
1646 if (this->_tlsHandle)
1647 {
1648 SSL_set_app_data (this->_tlsHandle.get (), this);
1649 }
1650
1651 other._tlsState = TlsState::NonEncrypted;
1652 }
1653
// move assignment (signature line, source 1659, absent): same transfer on top of the base
// move assignment.
1660 {
1661 BasicStreamSocket <Protocol>::operator= (std::move (other));
1662
1663 this->_tlsContext = std::move (other._tlsContext);
1664 this->_tlsHandle = std::move (other._tlsHandle);
1665 this->_tlsState = other._tlsState;
1666
1667 if (this->_tlsHandle)
1668 {
1669 SSL_set_app_data (this->_tlsHandle.get (), this);
1670 }
1671
1672 other._tlsState = TlsState::NonEncrypted;
1673
1674 return *this;
1675 }
1676
1680 virtual ~BasicTlsSocket () = default;
1681
1687 int connectEncrypted (const Endpoint& endpoint)
1688 {
1689 if (BasicStreamSocket <Protocol>::connect (endpoint) == -1)
1690 {
1691 return -1;
1692 }
1693
1694 if (this->startEncryption () == -1)
1695 {
1696 this->close ();
1697 return -1;
1698 }
1699
1700 return 0;
1701 }
1702
// creates an SSL object over the socket descriptor and starts the handshake; returns 0
// immediately when already encrypted. Clients (SSL_is_server () == 0) send the remote
// hostname as SNI when it is not empty.
// NOTE(review): the signature line (source 1707) is absent from this listing.
1708 {
1709 if (this->encrypted () == false)
1710 {
1711 this->_tlsHandle.reset (SSL_new (this->_tlsContext.get ()));
1712 if (this->_tlsHandle == nullptr)
1713 {
1714 lastError = make_error_code (Errc::OutOfMemory);
1715 return -1;
1716 }
1717
// attach the SSL object to the socket descriptor.
1718 if (SSL_set_fd (this->_tlsHandle.get (), this->_handle) != 1)
1719 {
1720 lastError = make_error_code (Errc::InvalidParam);
1721 this->_tlsHandle.reset ();
1722 return -1;
1723 }
1724
// pick the connect (client) or accept (server) role.
1725 if (SSL_is_server (this->_tlsHandle.get ()) == 0)
1726 {
1727 if (!this->_remote.hostname ().empty () && (SSL_set_tlsext_host_name (this->_tlsHandle.get (), this->_remote.hostname ().c_str ()) != 1))
1728 {
1729 lastError = make_error_code (Errc::InvalidParam);
1730 this->_tlsHandle.reset ();
1731 return -1;
1732 }
1733
1734 SSL_set_connect_state (this->_tlsHandle.get ());
1735 }
1736 else
1737 {
1738 SSL_set_accept_state (this->_tlsHandle.get ());
1739 }
1740
// let OpenSSL callbacks find this socket object.
1741 SSL_set_app_data (this->_tlsHandle.get (), this);
1742
1743 #ifdef DEBUG
1744 SSL_set_info_callback (this->_tlsHandle.get (), infoWrapper);
1745 #endif
1746
1747 return startHandshake ();
1748 }
1749
1750 return 0;
1751 }
1752
1758 bool waitEncrypted (int timeout = 0)
1759 {
1760 if (this->encrypted () == false)
1761 {
1762 if (this->_state == State::Connecting)
1763 {
1764 if (!this->waitConnected (timeout))
1765 {
1766 return false;
1767 }
1768
1769 if (this->startEncryption () == 0)
1770 {
1771 return true;
1772 }
1773 }
1774
1775 while ((lastError == Errc::TemporaryError) && (SSL_want_read (this->_tlsHandle.get ()) || SSL_want_write (this->_tlsHandle.get ())))
1776 {
1777 if (this->wait (SSL_want_read (this->_tlsHandle.get ()), SSL_want_write (this->_tlsHandle.get ()), timeout) == -1)
1778 {
1779 return false;
1780 }
1781
1782 if (this->startHandshake () == 0)
1783 {
1784 return true;
1785 }
1786 }
1787
1788 return false;
1789 }
1790
1791 return true;
1792 }
1793
/**
 * @brief send the TLS close_notify alert (if not already sent) before the stream disconnection.
 * @return -1 when SSL_shutdown () fails.
 */
// NOTE(review): source line 1859, the final statement, is absent from this listing. As
// written, control can reach the closing brace of this non-void function without a return,
// which is undefined behaviour. Source lines 1815, 1824-1825, 1838 and 1848 are absent too,
// so as written the WANT_READ/WANT_WRITE and protocol error branches return -1 without
// updating lastError, and the result == 1 branch is empty. Confirm against the full source.
1798 virtual int disconnect () override
1799 {
1800 if (this->encrypted ())
1801 {
1802 // check if the close_notify alert was already sent.
1803 if ((SSL_get_shutdown (this->_tlsHandle.get ()) & SSL_SENT_SHUTDOWN) == false)
1804 {
1805 // send the close_notify alert to the peer.
1806 int result = SSL_shutdown (this->_tlsHandle.get ());
1807 if (result < 0)
1808 {
1809 // shutdown was not successful.
1810 switch (SSL_get_error (this->_tlsHandle.get (), result))
1811 {
1812 case SSL_ERROR_WANT_READ:
1813 case SSL_ERROR_WANT_WRITE:
1814 // SSL_shutdown want read or want write.
1816 break;
1817 case SSL_ERROR_SYSCALL:
1818 // an error occurred at the socket level.
1819 switch (errno)
1820 {
1821 case 0:
1822 case ECONNRESET:
1823 case EPIPE:
// the peer is gone: treat the socket as disconnected.
1826 this->_state = State::Disconnected;
1827 break;
1828 default:
1829 lastError = std::make_error_code (static_cast <std::errc> (errno));
1830 break;
1831 }
1832 break;
1833 default:
1834 // SSL protocol error.
1835 #ifdef DEBUG
1836 std::cout << ERR_reason_error_string (ERR_get_error ()) << std::endl;
1837 #endif
1839 break;
1840 }
1841
1842 return -1;
1843 }
1844 else if (result == 1)
1845 {
1846 // shutdown was successfully completed.
1847 // close_notify alert was sent and the peer's close_notify alert was received.
1849 }
1850 else
1851 {
1852 // shutdown is not yet finished.
1853 // the close_notify was sent but the peer did not send it back yet.
1854 // SSL_read must be called to do a bidirectional shutdown.
1855 }
1856 }
1857 }
1858
1860 }
1861
1865 virtual void close () noexcept override
1866 {
1869 this->_tlsHandle.reset ();
1870 }
1871
1877 virtual bool waitReadyRead (int timeout = 0) const noexcept override
1878 {
1879 if (this->encrypted () && (SSL_want_read (this->_tlsHandle.get ()) || SSL_want_write (this->_tlsHandle.get ())))
1880 {
1881 return (this->wait (SSL_want_read (this->_tlsHandle.get ()), SSL_want_write (this->_tlsHandle.get ()), timeout) == 0);
1882 }
1883
1885 }
1886
1891 virtual int canRead () const noexcept override
1892 {
1893 if (this->encrypted ())
1894 {
1895 return SSL_pending (this->_tlsHandle.get ());
1896 }
1897
1899 }
1900
/**
 * @brief read data, decrypting with SSL_read () when the socket is encrypted.
 * @param data buffer receiving the data.
 * @param maxSize capacity of the buffer in bytes (passed to SSL_read () as int).
 * @return number of bytes read, -1 on failure.
 */
// NOTE(review): source lines 1921, 1926, 1929, 1939-1940 and 1953 are absent from this
// listing. As written, the WANT_* and protocol error branches do not update lastError
// (callers such as readExactly () and the lingering close loop test it against
// Errc::TemporaryError), and the body of the ZERO_RETURN `if` is empty. `int (maxSize)`
// truncates sizes above INT_MAX. Confirm against the full source.
1907 virtual int read (char *data, unsigned long maxSize) noexcept override
1908 {
1909 if (this->encrypted ())
1910 {
1911 // read data.
1912 int result = SSL_read (this->_tlsHandle.get (), data, int (maxSize));
1913 if (result < 1)
1914 {
1915 switch (SSL_get_error (this->_tlsHandle.get (), result))
1916 {
1917 case SSL_ERROR_WANT_READ:
1918 case SSL_ERROR_WANT_WRITE:
1919 case SSL_ERROR_WANT_X509_LOOKUP:
1920 // SSL_read want read, want write or want lookup.
1922 break;
1923 case SSL_ERROR_ZERO_RETURN:
1924 // a close notify alert was received.
1925 // we have to answer by sending a close notify alert too.
1927 if (SSL_get_shutdown (this->_tlsHandle.get ()) & SSL_SENT_SHUTDOWN)
1928 {
1930 }
1931 break;
1932 case SSL_ERROR_SYSCALL:
1933 // an error occurred at the socket level.
1934 switch (errno)
1935 {
1936 case 0:
1937 case ECONNRESET:
1938 case EPIPE:
// the peer is gone: treat the socket as disconnected.
1941 this->_state = State::Disconnected;
1942 break;
1943 default:
1944 lastError = std::make_error_code (static_cast <std::errc> (errno));
1945 break;
1946 }
1947 break;
1948 default:
1949 // SSL protocol error.
1950 #ifdef DEBUG
1951 std::cout << ERR_reason_error_string (ERR_get_error ()) << std::endl;
1952 #endif
1954 break;
1955 }
1956
1957 return -1;
1958 }
1959
1960 return result;
1961 }
1962
// not encrypted: plain socket read.
1963 return BasicStreamSocket <Protocol>::read (data, maxSize);
1964 }
1965
/**
 * @brief block until at least one byte can be written on the socket.
 * when the socket is encrypted and OpenSSL reports a pending read or write
 * (SSL_want_read / SSL_want_write), waits on the socket handle for exactly those
 * directions using wait().
 * @param timeout timeout forwarded to wait() (presumably milliseconds — confirm against wait()).
 * @return true when wait() returned 0.
 * NOTE(review): the fallback used otherwise (listing line 1978) is absent from this extract.
 */
1971 virtual bool waitReadyWrite (int timeout = 0) const noexcept override
1972 {
1973 if (this->encrypted () && (SSL_want_read (this->_tlsHandle.get ()) || SSL_want_write (this->_tlsHandle.get ())))
1974 {
// a TLS write may first need to read (e.g. during renegotiation), so wait on what OpenSSL asks for.
1975 return (this->wait (SSL_want_read (this->_tlsHandle.get ()), SSL_want_write (this->_tlsHandle.get ()), timeout) == 0);
1976 }
1977
1979 }
1980
/**
 * @brief write data on the socket.
 * when the socket is encrypted the data is written with SSL_write on the TLS handle,
 * otherwise the call is forwarded to BasicStreamSocket<Protocol>::write.
 * @param data data to write.
 * @param maxSize number of bytes to write.
 * @return the SSL_write result when it is >= 1, -1 when SSL_write reported a failure,
 * or the base-class result when the socket is not encrypted.
 * on SSL_ERROR_SYSCALL with errno 0, ECONNRESET or EPIPE the socket state becomes
 * Disconnected; any other errno is stored in lastError.
 * NOTE(review): maxSize is narrowed with int (maxSize); sizes above INT_MAX are not guarded here.
 * NOTE(review): listing lines 2001, 2006, 2009, 2019-2020 and 2033 are absent from this
 * extract, so the remaining handling in those branches is not documented here.
 */
1987 virtual int write (const char *data, unsigned long maxSize) noexcept override
1988 {
1989 if (this->encrypted ())
1990 {
1991 // write data.
1992 int result = SSL_write (this->_tlsHandle.get (), data, int (maxSize));
1993 if (result < 1)
1994 {
// classify the failure reported by SSL_write.
1995 switch (SSL_get_error (this->_tlsHandle.get (), result))
1996 {
1997 case SSL_ERROR_WANT_READ:
1998 case SSL_ERROR_WANT_WRITE:
1999 case SSL_ERROR_WANT_X509_LOOKUP:
2000 // SSL_write want read, want write or want lookup.
2002 break;
2003 case SSL_ERROR_ZERO_RETURN:
2004 // a close notify alert was received.
2005 // we have to answer by sending a close notify alert too.
2007 if (SSL_get_shutdown (this->_tlsHandle.get ()) & SSL_SENT_SHUTDOWN)
2008 {
2010 }
2011 break;
2012 case SSL_ERROR_SYSCALL:
2013 // an error occurred at the socket level.
2014 switch (errno)
2015 {
// errno 0, ECONNRESET and EPIPE are treated as the peer going away.
2016 case 0:
2017 case ECONNRESET:
2018 case EPIPE:
2021 this->_state = State::Disconnected;
2022 break;
2023 default:
2024 lastError = std::make_error_code (static_cast <std::errc> (errno));
2025 break;
2026 }
2027 break;
2028 default:
2029 // SSL protocol error.
2030 #ifdef DEBUG
2031 std::cout << ERR_reason_error_string (ERR_get_error ()) << std::endl;
2032 #endif
2034 break;
2035 }
2036
// every failure path reports -1 to the caller.
2037 return -1;
2038 }
2039
2040 return result;
2041 }
2042
2043 return BasicStreamSocket <Protocol>::write (data, maxSize);
2044 }
2045
2050 virtual bool encrypted () const noexcept override
2051 {
2052 return (this->_tlsState == TlsState::Encrypted);
2053 }
2054
2061 int setCertificate (const std::string& cert, const std::string& key = "")
2062 {
2063 if (((this->_tlsHandle) ? SSL_use_certificate_file (this->_tlsHandle.get (), cert.c_str (), SSL_FILETYPE_PEM)
2064 : SSL_CTX_use_certificate_file (this->_tlsContext.get (), cert.c_str (), SSL_FILETYPE_PEM)) == 0)
2065 {
2066 lastError = make_error_code (Errc::InvalidParam);
2067 return -1;
2068 }
2069
2070 if (key.size ())
2071 {
2072 if (((this->_tlsHandle) ? SSL_use_PrivateKey_file (this->_tlsHandle.get (), key.c_str (), SSL_FILETYPE_PEM)
2073 : SSL_CTX_use_PrivateKey_file (this->_tlsContext.get (), key.c_str (), SSL_FILETYPE_PEM)) == 0)
2074 {
2075 lastError = make_error_code (Errc::InvalidParam);
2076 return -1;
2077 }
2078 }
2079
2080 if (((this->_tlsHandle) ? SSL_check_private_key (this->_tlsHandle.get ())
2081 : SSL_CTX_check_private_key (this->_tlsContext.get ())) == 0)
2082 {
2083 lastError = make_error_code (Errc::InvalidParam);
2084 return -1;
2085 }
2086
2087 return 0;
2088 }
2089
2095 int setCaPath (const std::string& caPath)
2096 {
2097 struct stat st;
2098 if (stat (caPath.c_str (), &st) != 0 || !S_ISDIR (st.st_mode) ||
2099 SSL_CTX_load_verify_locations (this->_tlsContext.get (), nullptr, caPath.c_str ()) == 0)
2100 {
2101 lastError = make_error_code (Errc::InvalidParam);
2102 return -1;
2103 }
2104
2105 return 0;
2106 }
2107
2113 int setCaFile (const std::string& caFile)
2114 {
2115 struct stat st;
2116 if (stat (caFile.c_str (), &st) != 0 || !S_ISREG (st.st_mode) ||
2117 SSL_CTX_load_verify_locations (this->_tlsContext.get (), caFile.c_str (), nullptr) == 0)
2118 {
2119 lastError = make_error_code (Errc::InvalidParam);
2120 return -1;
2121 }
2122
2123 return 0;
2124 }
2125
2131 void setVerify (bool verify, int depth = -1)
2132 {
2133 if (verify == true)
2134 {
2135 SSL_CTX_set_verify (this->_tlsContext.get (), SSL_VERIFY_PEER, verifyWrapper);
2136 SSL_CTX_set_verify_depth (this->_tlsContext.get (), depth);
2137 }
2138 else
2139 {
2140 SSL_CTX_set_verify (this->_tlsContext.get (), SSL_VERIFY_NONE, nullptr);
2141 }
2142 }
2143
2149 int setCipher (const std::string &cipher)
2150 {
2151 if (((this->_tlsHandle) ? SSL_set_cipher_list (this->_tlsHandle.get (), cipher.c_str ())
2152 : SSL_CTX_set_cipher_list (this->_tlsContext.get (), cipher.c_str ())) == 0)
2153 {
2154 lastError = make_error_code (Errc::InvalidParam);
2155 return -1;
2156 }
2157
2158 return 0;
2159 }
2160
2166 int setCipher_1_3 (const std::string &cipher)
2167 {
2168 if (((this->_tlsHandle) ? SSL_set_ciphersuites (this->_tlsHandle.get (), cipher.c_str ())
2169 : SSL_CTX_set_ciphersuites (this->_tlsContext.get (), cipher.c_str ())) == 0)
2170 {
2171 lastError = make_error_code (Errc::InvalidParam);
2172 return -1;
2173 }
2174
2175 return 0;
2176 }
2177
2178 protected:
2187
2193 {
// body of startHandshake () (its signature, listing line 2192, is not part of this extract).
// runs SSL_do_handshake on the TLS handle; returns -1 when it reports a failure, 0 otherwise.
// on SSL_ERROR_SYSCALL with errno 0, ECONNRESET or EPIPE the state becomes Disconnected;
// any other errno is stored in lastError.
// NOTE(review): listing lines 2204, 2209, 2218, 2231 and 2238 are absent from this extract;
// line 2238 presumably records the successful handshake in the TLS state — confirm.
2194 // start the SSL handshake.
2195 int result = SSL_do_handshake (this->_tlsHandle.get ());
2196 if (result < 1)
2197 {
2198 switch (SSL_get_error (this->_tlsHandle.get (), result))
2199 {
2200 case SSL_ERROR_WANT_READ:
2201 case SSL_ERROR_WANT_WRITE:
2202 case SSL_ERROR_WANT_X509_LOOKUP:
2203 // SSL_do_handshake want read, want write or want lookup.
2205 break;
2206 case SSL_ERROR_ZERO_RETURN:
2207 // a close notify alert was received.
2208 // we have to answer by sending a close notify alert too.
2210 break;
2211 case SSL_ERROR_SYSCALL:
2212 // an error occurred at the socket level.
2213 switch (errno)
2214 {
// errno 0, ECONNRESET and EPIPE are treated as the peer going away.
2215 case 0:
2216 case ECONNRESET:
2217 case EPIPE:
2219 this->_state = State::Disconnected;
2220 break;
2221 default:
2222 lastError = std::make_error_code (static_cast <std::errc> (errno));
2223 break;
2224 }
2225 break;
2226 default:
2227 // SSL protocol error.
2228 #ifdef DEBUG
2229 std::cout << ERR_reason_error_string (ERR_get_error ()) << std::endl;
2230 #endif
2232 break;
2233 }
2234
// every failure path reports -1 to the caller.
2235 return -1;
2236 }
2237
2239
2240 return 0;
2241 }
2242
2249 static void infoWrapper (const SSL *ssl, int where, int ret)
2250 {
2251 assert (ssl);
2252 static_cast <BasicTlsSocket <Protocol>*> (SSL_get_app_data (ssl))->infoCallback (where, ret);
2253 }
2254
2260 void infoCallback (int where, int ret) const
2261 {
2262 if (where & SSL_CB_ALERT)
2263 {
2264 std::cout << "SSL/TLS Alert ";
2265 (where & SSL_CB_READ) ? std::cout << "[read] " : std::cout << "[write] ";
2266 std::cout << SSL_alert_type_string_long (ret) << ":";
2267 std::cout << SSL_alert_desc_string_long (ret);
2268 std::cout << std::endl;
2269 }
2270 else if (where & SSL_CB_LOOP)
2271 {
2272 std::cout << "SSL/TLS State ";
2273 (SSL_in_connect_init (this->_tlsHandle.get ())) ? std::cout << "[connect] " : (SSL_in_accept_init (this->_tlsHandle.get ())) ? std::cout << "[accept] " : std::cout << "[undefined] ";
2274 std::cout << SSL_state_string_long (this->_tlsHandle.get ());
2275 std::cout << std::endl;
2276 }
2277 else if (where & SSL_CB_HANDSHAKE_START)
2278 {
2279 std::cout << "SSL/TLS Handshake [Start] "<< SSL_state_string_long (this->_tlsHandle.get ()) << std::endl;
2280 }
2281 else if (where & SSL_CB_HANDSHAKE_DONE)
2282 {
2283 std::cout << "SSL/TLS Handshake [Done] "<< SSL_state_string_long (this->_tlsHandle.get ()) << std::endl;
2284 std::cout << SSL_CTX_sess_number (this->_tlsContext.get ()) << " items in the session cache"<< std::endl;
2285 std::cout << SSL_CTX_sess_connect (this->_tlsContext.get ()) << " client connects"<< std::endl;
2286 std::cout << SSL_CTX_sess_connect_good (this->_tlsContext.get ()) << " client connects that finished"<< std::endl;
2287 std::cout << SSL_CTX_sess_connect_renegotiate (this->_tlsContext.get ()) << " client renegotiations requested"<< std::endl;
2288 std::cout << SSL_CTX_sess_accept (this->_tlsContext.get ()) << " server connects"<< std::endl;
2289 std::cout << SSL_CTX_sess_accept_good (this->_tlsContext.get ()) << " server connects that finished"<< std::endl;
2290 std::cout << SSL_CTX_sess_accept_renegotiate (this->_tlsContext.get ()) << " server renegotiations requested"<< std::endl;
2291 std::cout << SSL_CTX_sess_hits (this->_tlsContext.get ()) << " session cache hits"<< std::endl;
2292 std::cout << SSL_CTX_sess_cb_hits (this->_tlsContext.get ()) << " external session cache hits"<< std::endl;
2293 std::cout << SSL_CTX_sess_misses (this->_tlsContext.get ()) << " session cache misses"<< std::endl;
2294 std::cout << SSL_CTX_sess_timeouts (this->_tlsContext.get ()) << " session cache timeouts"<< std::endl;
2295 std::cout << "negotiated " << SSL_get_cipher (this->_tlsHandle.get ()) << " cipher suite" << std::endl;
2296 }
2297 }
2298
2305 static int verifyWrapper (int preverified, X509_STORE_CTX *context)
2306 {
2307 SSL* ssl = static_cast <SSL*> (X509_STORE_CTX_get_ex_data (context, SSL_get_ex_data_X509_STORE_CTX_idx ()));
2308
2309 assert (ssl);
2310 return static_cast <BasicTlsSocket <Protocol>*> (SSL_get_app_data (ssl))->verifyCallback (preverified, context);
2311 }
2312
2319 int verifyCallback (int preverified, X509_STORE_CTX *context) const
2320 {
2321 int maxDepth = SSL_get_verify_depth (this->_tlsHandle.get ());
2322 int dpth = X509_STORE_CTX_get_error_depth (context);
2323
2324 #ifdef DEBUG
2325 std::cout << "verification started at depth="<< dpth << std::endl;
2326 #endif
2327
2328 // catch a too long certificate chain.
2329 if ((maxDepth >= 0) && (dpth > maxDepth))
2330 {
2331 preverified = 0;
2332 X509_STORE_CTX_set_error (context, X509_V_ERR_CERT_CHAIN_TOO_LONG);
2333 }
2334
2335 if (!preverified)
2336 {
2337 #ifdef DEBUG
2338 std::cout << "verification failed at depth=" << dpth << " - " << X509_verify_cert_error_string (X509_STORE_CTX_get_error (context)) << std::endl;
2339 #endif
2340 return 0;
2341 }
2342
2343 // check the certificate host name.
2344 if (!verifyCert (context))
2345 {
2346 #ifdef DEBUG
2347 std::cout << "rejected by CERT at depth=" << dpth << std::endl;
2348 #endif
2349 return 0;
2350 }
2351
2352 // check the revocation list.
2353 /*if (!verifyCrl (context))
2354 {
2355 #ifdef DEBUG
2356 std::cout << "rejected by CRL at depth=" << dpth << std::endl;
2357 #endif
2358 return 0;
2359 }*/
2360
2361 // check ocsp.
2362 /*if (!verifyOcsp (context))
2363 {
2364 #ifdef DEBUG
2365 std::cout << "rejected by OCSP at depth=" << dpth << std::endl;
2366 #endif
2367 return 0;
2368 }*/
2369
2370 #ifdef DEBUG
2371 std::cout << "certificate accepted at depth=" << dpth << std::endl;
2372 #endif
2373
2374 return 1;
2375 }
2376
2382 int verifyCert (X509_STORE_CTX *context) const
2383 {
2384 int depth = X509_STORE_CTX_get_error_depth (context);
2385 X509* cert = X509_STORE_CTX_get_current_cert (context);
2386
2387 char buf[256];
2388 X509_NAME_oneline (X509_get_subject_name (cert), buf, sizeof (buf));
2389 #ifdef DEBUG
2390 std::cout << "subject=" << buf << std::endl;
2391 #endif
2392
2393 // check the certificate host name
2394 if (depth == 0)
2395 {
2396 // confirm a match between the hostname and the hostnames listed in the certificate.
2397 if (!checkHostName (cert))
2398 {
2399 #ifdef DEBUG
2400 std::cout << "no match for hostname in the certificate" << std::endl;
2401 #endif
2402 return 0;
2403 }
2404 }
2405
2406 return 1;
2407 }
2408
2414 bool checkHostName (X509 *certificate) const
2415 {
2416 bool match = false;
2417
2418 // get alternative names.
2419 join::StackOfGeneralNamePtr altnames (reinterpret_cast <STACK_OF (GENERAL_NAME)*> (X509_get_ext_d2i (certificate, NID_subject_alt_name, 0, 0)));
2420 if (altnames)
2421 {
2422 for (int i = 0; (i < sk_GENERAL_NAME_num (altnames.get ())) && !match; ++i)
2423 {
2424 // get a handle to alternative name.
2425 GENERAL_NAME *current_name = sk_GENERAL_NAME_value (altnames.get (), i);
2426
2427 if (current_name->type == GEN_DNS)
2428 {
2429 // get data and length.
2430 const char *host = reinterpret_cast <const char *> (ASN1_STRING_get0_data (current_name->d.ia5));
2431 size_t len = size_t (ASN1_STRING_length (current_name->d.ia5));
2432 std::string pattern (host, host + len), serverName (this->_remote.hostname ());
2433
2434 // strip off trailing dots.
2435 if (pattern.back () == '.')
2436 {
2437 pattern.pop_back ();
2438 }
2439
2440 if (serverName.back () == '.')
2441 {
2442 serverName.pop_back ();
2443 }
2444
2445 // compare to pattern.
2446 if (fnmatch (pattern.c_str (), serverName.c_str (), 0) == 0)
2447 {
2448 // an alternative name matched the server hostname.
2449 match = true;
2450 }
2451 }
2452 }
2453 }
2454
2455 return match;
2456 }
2457
2463 /*int verifyCrl ([[maybe_unused]]X509_STORE_CTX *context) const
2464 {
2465 return 1;
2466 }*/
2467
2473 /*int verifyOcsp ([[maybe_unused]]X509_STORE_CTX *context) const
2474 {
2475 return 1;
2476 }*/
2477
2480
2483
2486
2488 friend class BasicTlsAcceptor <Protocol>;
2489 };
2490
2497 template <class Protocol>
2498 constexpr bool operator< (const BasicTlsSocket <Protocol>& a, const BasicTlsSocket <Protocol>& b) noexcept
2499 {
2500 return a.handle () < b.handle ();
2501 }
2502}
2503
2504namespace std
2505{
2507 template <> struct is_error_condition_enum <join::TlsErrc> : public true_type {};
2508}
2509
2510#endif
basic datagram socket class.
Definition socket.hpp:645
virtual ~BasicDatagramSocket()=default
Destroy the instance.
typename BasicSocket< Protocol >::Mode Mode
Definition socket.hpp:648
BasicDatagramSocket(const BasicDatagramSocket &other)=delete
Copy constructor.
BasicDatagramSocket(Mode mode, int ttl=60)
Create instance specifying the mode.
Definition socket.hpp:665
virtual int connect(const Endpoint &endpoint)
make a connection to the given endpoint.
Definition socket.hpp:798
std::unique_ptr< BasicDatagramSocket< Protocol > > Ptr
Definition socket.hpp:647
virtual int bindToDevice(const std::string &device) noexcept
assigns the specified device to the socket.
Definition socket.hpp:764
virtual bool connected() noexcept
check if the socket is connected.
Definition socket.hpp:1062
virtual int setOption(Option option, int value) noexcept override
set the given option to the given value.
Definition socket.hpp:958
Endpoint _remote
remote endpoint.
Definition socket.hpp:1116
virtual int read(char *data, unsigned long maxSize) noexcept override
read data.
Definition socket.hpp:876
int _ttl
packet time to live.
Definition socket.hpp:1119
virtual int write(const char *data, unsigned long maxSize) noexcept override
write data.
Definition socket.hpp:923
virtual void close() noexcept override
close the socket handle.
Definition socket.hpp:864
typename Protocol::Endpoint Endpoint
Definition socket.hpp:651
int ttl() const
returns the Time-To-Live value.
Definition socket.hpp:1109
BasicDatagramSocket(int ttl=60)
Default constructor.
Definition socket.hpp:656
int mtu() const
get socket mtu.
Definition socket.hpp:1071
BasicDatagramSocket & operator=(const BasicDatagramSocket &other)=delete
Copy assignment operator.
virtual int readFrom(char *data, unsigned long maxSize, Endpoint *endpoint=nullptr) noexcept
read data on the socket.
Definition socket.hpp:888
virtual int writeTo(const char *data, unsigned long maxSize, const Endpoint &endpoint) noexcept
write data on the socket.
Definition socket.hpp:935
typename BasicSocket< Protocol >::Option Option
Definition socket.hpp:649
virtual int disconnect()
shutdown the connection.
Definition socket.hpp:835
virtual int open(const Protocol &protocol=Protocol()) noexcept override
open socket using the given protocol.
Definition socket.hpp:723
BasicDatagramSocket(BasicDatagramSocket &&other)
Move constructor.
Definition socket.hpp:688
typename BasicSocket< Protocol >::State State
Definition socket.hpp:650
const Endpoint & remoteEndpoint() const
determine the remote endpoint associated with this socket.
Definition socket.hpp:1053
basic socket class.
Definition socket.hpp:60
bool opened() const noexcept
check if the socket is opened.
Definition socket.hpp:489
BasicSocket & operator=(const BasicSocket &other)=delete
copy assignment operator.
virtual int open(const Protocol &protocol=Protocol()) noexcept
open socket using the given protocol.
Definition socket.hpp:194
static uint16_t checksum(const uint16_t *data, size_t len, uint16_t current=0)
get standard 1s complement checksum.
Definition socket.hpp:546
void setMode(Mode mode) noexcept
set the socket to the non-blocking or blocking mode.
Definition socket.hpp:380
State
socket states.
Definition socket.hpp:102
@ Disconnected
Definition socket.hpp:106
@ Connecting
Definition socket.hpp:103
@ Disconnecting
Definition socket.hpp:105
@ Closed
Definition socket.hpp:107
@ Connected
Definition socket.hpp:104
Protocol _protocol
protocol.
Definition socket.hpp:625
virtual int setOption(Option option, int value) noexcept
set the given option to the given value.
Definition socket.hpp:407
int handle() const noexcept override
get socket native handle.
Definition socket.hpp:534
std::unique_ptr< BasicSocket< Protocol > > Ptr
Definition socket.hpp:62
Mode _mode
socket mode.
Definition socket.hpp:619
virtual void close() noexcept
close the socket.
Definition socket.hpp:223
int family() const noexcept
get socket address family.
Definition socket.hpp:507
Mode
socket modes.
Definition socket.hpp:69
@ Blocking
Definition socket.hpp:70
@ NonBlocking
Definition socket.hpp:71
virtual bool waitReadyRead(int timeout=0) const noexcept
block until new data is available for reading.
Definition socket.hpp:294
virtual ~BasicSocket()
destroy the socket instance.
Definition socket.hpp:181
virtual int bind(const Endpoint &endpoint) noexcept
assigns the specified endpoint to the socket.
Definition socket.hpp:238
Endpoint localEndpoint() const
determine the local endpoint associated with this socket.
Definition socket.hpp:472
BasicSocket()
default constructor.
Definition socket.hpp:113
int type() const noexcept
get the protocol communication semantic.
Definition socket.hpp:516
State _state
socket state.
Definition socket.hpp:616
typename Protocol::Endpoint Endpoint
Definition socket.hpp:63
virtual bool encrypted() const noexcept
check if the socket is secure.
Definition socket.hpp:498
Option
socket options.
Definition socket.hpp:78
@ MulticastTtl
Definition socket.hpp:92
@ SndBuffer
Definition socket.hpp:84
@ Ttl
Definition socket.hpp:90
@ Broadcast
Definition socket.hpp:89
@ RcvError
Definition socket.hpp:94
@ ReusePort
Definition socket.hpp:88
@ MulticastLoop
Definition socket.hpp:91
@ KeepAlive
Definition socket.hpp:80
@ ReuseAddr
Definition socket.hpp:87
@ KeepCount
Definition socket.hpp:83
@ KeepIntvl
Definition socket.hpp:82
@ AuxData
Definition socket.hpp:95
@ PathMtuDiscover
Definition socket.hpp:93
@ NoDelay
Definition socket.hpp:79
@ TimeStamp
Definition socket.hpp:86
@ KeepIdle
Definition socket.hpp:81
@ RcvBuffer
Definition socket.hpp:85
virtual bool waitReadyWrite(int timeout=0) const noexcept
block until at least one byte can be written.
Definition socket.hpp:341
BasicSocket(BasicSocket &&other)
move constructor.
Definition socket.hpp:144
virtual int read(char *data, unsigned long maxSize) noexcept
read data.
Definition socket.hpp:305
virtual int write(const char *data, unsigned long maxSize) noexcept
write data.
Definition socket.hpp:352
BasicSocket(Mode mode)
create socket instance specifying the mode.
Definition socket.hpp:122
int wait(bool wantRead, bool wantWrite, int timeout) const noexcept
wait for the socket handle to become ready.
Definition socket.hpp:579
int protocol() const noexcept
get socket protocol.
Definition socket.hpp:525
int _handle
socket handle.
Definition socket.hpp:622
virtual int canRead() const noexcept
get the number of readable bytes.
Definition socket.hpp:275
BasicSocket(const BasicSocket &other)=delete
copy constructor.
basic stream acceptor class.
Definition protocol.hpp:45
basic stream socket class.
Definition socket.hpp:1139
std::unique_ptr< BasicStreamSocket< Protocol > > Ptr
Definition socket.hpp:1141
virtual ~BasicStreamSocket()=default
destroy the instance.
typename BasicDatagramSocket< Protocol >::Option Option
Definition socket.hpp:1143
typename BasicDatagramSocket< Protocol >::Mode Mode
Definition socket.hpp:1142
typename BasicDatagramSocket< Protocol >::State State
Definition socket.hpp:1144
virtual int setOption(Option option, int value) noexcept override
set the given option to the given value.
Definition socket.hpp:1393
virtual bool connecting() const noexcept
check if the socket is connecting.
Definition socket.hpp:1443
BasicStreamSocket(const BasicStreamSocket &other)=delete
copy constructor.
BasicStreamSocket(BasicStreamSocket &&other)
move constructor.
Definition socket.hpp:1181
virtual bool waitConnected(int timeout=0)
block until connected.
Definition socket.hpp:1208
int readExactly(std::string &data, unsigned long size, int timeout=0)
read data until size is reached or an error occurred.
Definition socket.hpp:1350
int writeExactly(const char *data, unsigned long size, int timeout=0)
write data until size is reached or an error occurred.
Definition socket.hpp:1363
BasicStreamSocket(Mode mode)
create instance specifying the mode.
Definition socket.hpp:1159
typename Protocol::Endpoint Endpoint
Definition socket.hpp:1145
BasicStreamSocket & operator=(const BasicStreamSocket &other)=delete
copy assignment operator.
virtual bool connected() noexcept override
check if the socket is connected.
Definition socket.hpp:1452
virtual bool waitDisconnected(int timeout=0)
wait until the connection has been shut down.
Definition socket.hpp:1275
virtual int disconnect() override
shutdown the connection.
Definition socket.hpp:1233
BasicStreamSocket()
default constructor.
Definition socket.hpp:1150
int readExactly(char *data, unsigned long size, int timeout=0)
read data until size is reached or an error occurred.
Definition socket.hpp:1319
basic TLS acceptor class.
Definition protocol.hpp:46
basic TLS socket class.
Definition socket.hpp:1547
BasicTlsSocket()
default constructor.
Definition socket.hpp:1558
BasicTlsSocket & operator=(const BasicTlsSocket &other)=delete
copy assignment operator.
int setCaFile(const std::string &caFile)
set the location of the trusted CA certificate file.
Definition socket.hpp:2113
std::unique_ptr< BasicTlsSocket< Protocol > > Ptr
Definition socket.hpp:1549
virtual int canRead() const noexcept override
get the number of readable bytes.
Definition socket.hpp:1891
int startEncryption()
start socket encryption (perform TLS handshake).
Definition socket.hpp:1707
bool waitEncrypted(int timeout=0)
wait until the TLS handshake is performed or a timeout occurs (non-blocking socket).
Definition socket.hpp:1758
virtual bool encrypted() const noexcept override
check if the socket is secure.
Definition socket.hpp:2050
BasicTlsSocket(Mode mode, join::SslCtxPtr tlsContext)
Create socket instance specifying the socket mode and TLS context.
Definition socket.hpp:1613
join::SslCtxPtr _tlsContext
TLS context.
Definition socket.hpp:2479
virtual int read(char *data, unsigned long maxSize) noexcept override
read data on the socket.
Definition socket.hpp:1907
void setVerify(bool verify, int depth=-1)
Enable/Disable the verification of the peer certificate.
Definition socket.hpp:2131
join::SslPtr _tlsHandle
TLS handle.
Definition socket.hpp:2482
int verifyCallback(int preverified, X509_STORE_CTX *context) const
trusted CA certificates verification callback.
Definition socket.hpp:2319
virtual bool waitReadyRead(int timeout=0) const noexcept override
block until new data is available for reading.
Definition socket.hpp:1877
int verifyCert(X509_STORE_CTX *context) const
verify certificate validity.
Definition socket.hpp:2382
void infoCallback(int where, int ret) const
state information callback.
Definition socket.hpp:2260
BasicTlsSocket(Mode mode)
create instance specifying the mode.
Definition socket.hpp:1567
TlsState
TLS state.
Definition socket.hpp:2183
@ Encrypted
Definition socket.hpp:2184
@ NonEncrypted
Definition socket.hpp:2185
TlsState _tlsState
TLS state.
Definition socket.hpp:2485
bool checkHostName(X509 *certificate) const
confirm a match between the hostname contacted and the hostnames listed in the certificate.
Definition socket.hpp:2414
typename BasicStreamSocket< Protocol >::Mode Mode
Definition socket.hpp:1550
virtual int disconnect() override
shutdown the connection.
Definition socket.hpp:1798
int setCertificate(const std::string &cert, const std::string &key="")
set the certificate and the private key.
Definition socket.hpp:2061
virtual bool waitReadyWrite(int timeout=0) const noexcept override
block until at least one byte can be written on the socket.
Definition socket.hpp:1971
int setCipher(const std::string &cipher)
set the cipher list (TLSv1.2 and below).
Definition socket.hpp:2149
typename Protocol::Endpoint Endpoint
Definition socket.hpp:1553
virtual void close() noexcept override
close the socket handle.
Definition socket.hpp:1865
int setCipher_1_3(const std::string &cipher)
set the cipher list (TLSv1.3).
Definition socket.hpp:2166
virtual ~BasicTlsSocket()=default
destroy the instance.
int setCaPath(const std::string &caPath)
set the location of the trusted CA certificates.
Definition socket.hpp:2095
typename BasicStreamSocket< Protocol >::State State
Definition socket.hpp:1552
BasicTlsSocket(const BasicTlsSocket &other)=delete
copy constructor.
int startHandshake()
Start SSL handshake.
Definition socket.hpp:2192
typename BasicStreamSocket< Protocol >::Option Option
Definition socket.hpp:1551
static void infoWrapper(const SSL *ssl, int where, int ret)
c style callback wrapper for the state information callback.
Definition socket.hpp:2249
virtual int write(const char *data, unsigned long maxSize) noexcept override
write data on the socket.
Definition socket.hpp:1987
BasicTlsSocket(join::SslCtxPtr tlsContext)
create instance specifying TLS context.
Definition socket.hpp:1603
static int verifyWrapper(int preverified, X509_STORE_CTX *context)
c style callback wrapper for the Trusted CA certificates verification callback.
Definition socket.hpp:2305
BasicTlsSocket(BasicTlsSocket &&other)
move constructor.
Definition socket.hpp:1640
int connectEncrypted(const Endpoint &endpoint)
make an encrypted connection to the given endpoint.
Definition socket.hpp:1687
Event handler interface class.
Definition reactor.hpp:44
TLS error category.
Definition socket.hpp:1506
virtual std::string message(int code) const
translate TLS error code to human readable error string.
Definition socket.cpp:44
virtual const char * name() const noexcept
get TLS error category name.
Definition socket.cpp:35
const std::string key(65, 'a')
key.
Definition acceptor.hpp:32
bool operator<(const BasicUnixEndpoint< Protocol > &a, const BasicUnixEndpoint< Protocol > &b) noexcept
compare if endpoint is lower.
Definition endpoint.hpp:207
std::error_code make_error_code(join::Errc code)
Create an std::error_code object.
Definition error.cpp:154
std::unique_ptr< SSL, SslDelete > SslPtr
Definition openssl.hpp:225
const std::string defaultCipher_1_3
Definition openssl.cpp:39
std::unique_ptr< SSL_CTX, SslCtxDelete > SslCtxPtr
Definition openssl.hpp:240
TlsErrc
TLS error codes.
Definition socket.hpp:1497
const std::error_category & getTlsCategory()
get error category.
Definition socket.cpp:61
std::unique_ptr< STACK_OF(GENERAL_NAME), StackOfGeneralNameDelete > StackOfGeneralNamePtr
Definition openssl.hpp:210
const std::string defaultCipher
Definition openssl.cpp:36
std::error_condition make_error_condition(join::Errc code)
Create an std::error_condition object.
Definition error.cpp:163
Definition error.hpp:106