join 1.0
lightweight network framework library
Loading...
Searching...
No Matches
socket.hpp
Go to the documentation of this file.
1
25#ifndef JOIN_CORE_SOCKET_HPP
26#define JOIN_CORE_SOCKET_HPP
27
28// libjoin.
29#include <join/protocol.hpp>
30#include <join/endpoint.hpp>
31#include <join/openssl.hpp>
32#include <join/reactor.hpp>
33#include <join/utils.hpp>
34#include <join/error.hpp>
35
36// Libraries.
37#include <openssl/err.h>
38
39// C++.
40#include <type_traits>
41#include <iostream>
42
43// C.
44#include <netinet/tcp.h>
45#include <linux/icmp.h>
46#include <sys/ioctl.h>
47#include <sys/stat.h>
48#include <fnmatch.h>
49#include <cassert>
50#include <fcntl.h>
51#include <poll.h>
52
53namespace join
54{
58 template <class Protocol>
60 {
61 public:
62 using Ptr = std::unique_ptr<BasicSocket<Protocol>>;
63 using Endpoint = typename Protocol::Endpoint;
64
68 enum Mode
69 {
72 };
73
97
109
115 {
116 }
117
123 : _mode (mode)
124 {
125 }
126
131 BasicSocket (const BasicSocket& other) = delete;
132
138 BasicSocket& operator= (const BasicSocket& other) = delete;
139
145 : _state (other._state)
146 , _mode (other._mode)
147 , _handle (other._handle)
148 , _protocol (other._protocol)
149 {
150 other._state = State::Closed;
151 other._mode = Mode::NonBlocking;
152 other._handle = -1;
153 other._protocol = Protocol ();
154 }
155
162 {
163 this->close ();
164
165 this->_state = other._state;
166 this->_mode = other._mode;
167 this->_handle = other._handle;
168 this->_protocol = other._protocol;
169
170 other._state = State::Closed;
171 other._mode = Mode::NonBlocking;
172 other._handle = -1;
173 other._protocol = Protocol ();
174
175 return *this;
176 }
177
181 virtual ~BasicSocket ()
182 {
183 if (this->_handle != -1)
184 {
185 ::close (this->_handle);
186 }
187 }
188
/**
 * @brief open the socket using the given protocol.
 * Refuses with Errc::InUse when the socket is not in the closed state. Otherwise the native
 * descriptor is created (SOCK_NONBLOCK is OR-ed into the type in non-blocking mode) and the
 * protocol is recorded.
 * @param protocol protocol providing family, type and protocol number for ::socket.
 * @return 0 on success, -1 on failure (lastError is set).
 * NOTE(review): listing line 214 is absent from this extract; no state update is visible
 * here, so restore that line from the upstream header before relying on this block.
 */
194 virtual int open (const Protocol& protocol = Protocol ()) noexcept
195 {
196 if (this->_state != State::Closed)
197 {
198 lastError = make_error_code (Errc::InUse);
199 return -1;
200 }
201
202 if (this->_mode == Mode::NonBlocking)
203 this->_handle = ::socket (protocol.family (), protocol.type () | SOCK_NONBLOCK, protocol.protocol ());
204 else
205 this->_handle = ::socket (protocol.family (), protocol.type (), protocol.protocol ());
206
207 if (this->_handle == -1)
208 {
209 lastError = std::error_code (errno, std::generic_category ());
210 this->close ();
211 return -1;
212 }
213
215 this->_protocol = protocol;
216
217 return 0;
218 }
219
223 virtual void close () noexcept
224 {
225 if (this->_state != State::Closed)
226 {
227 ::close (this->_handle);
228 this->_state = State::Closed;
229 this->_handle = -1;
230 }
231 }
232
238 virtual int bind (const Endpoint& endpoint) noexcept
239 {
240 if ((this->_state == State::Closed) && (this->open (endpoint.protocol ()) == -1))
241 {
242 return -1;
243 }
244
245 if (endpoint.protocol ().family () == AF_PACKET)
246 {
247 if (reinterpret_cast<const struct sockaddr_ll*> (endpoint.addr ())->sll_ifindex == 0)
248 {
249 lastError = std::make_error_code (std::errc::no_such_device);
250 return -1;
251 }
252 }
253 else if ((endpoint.protocol ().family () == AF_INET6) || (endpoint.protocol ().family () == AF_INET))
254 {
255 this->setOption (Option::ReuseAddr, 1);
256 }
257 else if (endpoint.protocol ().family () == AF_UNIX)
258 {
259 ::unlink (endpoint.device ().c_str ());
260 }
261
262 if (::bind (this->_handle, endpoint.addr (), endpoint.length ()) == -1)
263 {
264 lastError = std::error_code (errno, std::generic_category ());
265 return -1;
266 }
267
268 return 0;
269 }
270
275 virtual int canRead () const noexcept
276 {
277 int available = 0;
278
279 // check if data can be read in the socket internal buffer.
280 if (::ioctl (this->_handle, FIONREAD, &available) == -1)
281 {
282 lastError = std::error_code (errno, std::generic_category ());
283 return -1;
284 }
285
286 return available;
287 }
288
294 virtual bool waitReadyRead (int timeout = 0) const noexcept
295 {
296 return (this->wait (true, false, timeout) == 0);
297 }
298
/**
 * @brief read data from the socket with a single ::recvmsg call.
 * @param data destination buffer.
 * @param maxSize capacity of the destination buffer in bytes.
 * @return number of bytes received, -1 when ::recvmsg returns less than one byte
 * (lastError is set from errno when it returned -1).
 * NOTE(review): listing line 328 (body of the zero-byte branch) is absent from this extract;
 * restore it from the upstream header before relying on this block.
 */
305 virtual int read (char* data, unsigned long maxSize) noexcept
306 {
307 struct iovec iov;
308 iov.iov_base = data;
309 iov.iov_len = maxSize;
310
311 struct msghdr message;
312 message.msg_name = nullptr;
313 message.msg_namelen = 0;
314 message.msg_iov = &iov;
315 message.msg_iovlen = 1;
316 message.msg_control = nullptr;
317 message.msg_controllen = 0;
318
319 int size = ::recvmsg (this->_handle, &message, 0);
320 if (size < 1)
321 {
322 if (size == -1)
323 {
324 lastError = std::error_code (errno, std::generic_category ());
325 }
326 else
327 {
329 }
330 return -1;
331 }
332
333 return size;
334 }
335
341 virtual bool waitReadyWrite (int timeout = 0) const noexcept
342 {
343 return (this->wait (false, true, timeout) == 0);
344 }
345
352 virtual int write (const char* data, unsigned long maxSize) noexcept
353 {
354 struct iovec iov;
355 iov.iov_base = const_cast<char*> (data);
356 iov.iov_len = maxSize;
357
358 struct msghdr message;
359 message.msg_name = nullptr;
360 message.msg_namelen = 0;
361 message.msg_iov = &iov;
362 message.msg_iovlen = 1;
363 message.msg_control = nullptr;
364 message.msg_controllen = 0;
365
366 int result = ::sendmsg (this->_handle, &message, 0);
367 if (result == -1)
368 {
369 lastError = std::error_code (errno, std::generic_category ());
370 return -1;
371 }
372
373 return result;
374 }
375
380 void setMode (Mode mode) noexcept
381 {
382 this->_mode = mode;
383
384 if (this->_state != State::Closed)
385 {
386 int flags = ::fcntl (this->_handle, F_GETFL, 0);
387
388 if (this->_mode == Mode::NonBlocking)
389 {
390 flags = flags | O_NONBLOCK;
391 }
392 else
393 {
394 flags = flags & ~O_NONBLOCK;
395 }
396
397 ::fcntl (this->_handle, F_SETFL, flags);
398 }
399 }
400
/**
 * @brief set a socket-level option with ::setsockopt.
 * @param option option identifier, mapped to an (optlevel, optname) pair by the switch.
 * @param value integer value passed to ::setsockopt.
 * @return 0 on success, -1 on failure.
 * NOTE(review): the case labels at listing lines 413, 418, 423, 428, 433, 438 and 443, and
 * line 454 in the default branch, are absent from this extract; restore them from the
 * upstream header before relying on this block.
 */
407 virtual int setOption (Option option, int value) noexcept
408 {
409 int optlevel, optname;
410
411 switch (option)
412 {
414 optlevel = SOL_SOCKET;
415 optname = SO_KEEPALIVE;
416 break;
417
419 optlevel = SOL_SOCKET;
420 optname = SO_SNDBUF;
421 break;
422
424 optlevel = SOL_SOCKET;
425 optname = SO_RCVBUF;
426 break;
427
429 optlevel = SOL_SOCKET;
430 optname = SO_TIMESTAMP;
431 break;
432
434 optlevel = SOL_SOCKET;
435 optname = SO_REUSEADDR;
436 break;
437
439 optlevel = SOL_SOCKET;
440 optname = SO_REUSEPORT;
441 break;
442
444 optlevel = SOL_SOCKET;
445 optname = SO_BROADCAST;
446 break;
447
448 case Option::AuxData:
449 optlevel = SOL_PACKET;
450 optname = PACKET_AUXDATA;
451 break;
452
453 default:
455 return -1;
456 }
457
458 int result = ::setsockopt (this->_handle, optlevel, optname, &value, sizeof (value));
459 if (result == -1)
460 {
461 lastError = std::error_code (errno, std::generic_category ());
462 return -1;
463 }
464
465 return 0;
466 }
467
472 Endpoint localEndpoint () const noexcept
473 {
474 struct sockaddr_storage sa;
475 socklen_t sa_len = sizeof (struct sockaddr_storage);
476
477 if (::getsockname (this->_handle, reinterpret_cast<struct sockaddr*> (&sa), &sa_len) == -1)
478 {
479 return {};
480 }
481
482 return Endpoint (reinterpret_cast<struct sockaddr*> (&sa), sa_len);
483 }
484
489 bool opened () const noexcept
490 {
491 return (this->_state != State::Closed);
492 }
493
498 virtual bool encrypted () const noexcept
499 {
500 return false;
501 }
502
507 int family () const noexcept
508 {
509 return this->_protocol.family ();
510 }
511
516 int type () const noexcept
517 {
518 return this->_protocol.type ();
519 }
520
525 int protocol () const noexcept
526 {
527 return this->_protocol.protocol ();
528 }
529
534 int handle () const noexcept
535 {
536 return this->_handle;
537 }
538
/**
 * @brief compute the 16-bit one's complement checksum of a buffer.
 * Words are summed in host byte order; a trailing odd byte is added as if it were followed
 * by a zero pad byte in memory.
 * @param data buffer to sum (read as 16-bit words).
 * @param len buffer length in bytes.
 * @param current initial value added to the sum before folding.
 * @return one's complement of the folded 16-bit sum.
 */
static uint16_t checksum (const uint16_t* data, size_t len, uint16_t current = 0)
{
    // 64-bit accumulator: a 32-bit one silently wraps (dropping carries) once the buffer
    // exceeds roughly 128 KiB of large words.
    uint64_t sum = current;

    while (len > 1)
    {
        sum += *data++;
        len -= 2;
    }

    if (len == 1)
    {
#if __BYTE_ORDER == __LITTLE_ENDIAN
        sum += *reinterpret_cast<const uint8_t*> (data);
#else
        sum += *reinterpret_cast<const uint8_t*> (data) << 8;
#endif
    }

    // end-around carry: fold until everything fits in 16 bits.
    while ((sum >> 16) != 0)
    {
        sum = (sum >> 16) + (sum & 0xffff);
    }

    return static_cast<uint16_t> (~sum);
}
570
571 protected:
579 int wait (bool wantRead, bool wantWrite, int timeout) const noexcept
580 {
581 struct pollfd handle = {.fd = this->_handle, .events = 0, .revents = 0};
582
583 if (wantRead)
584 {
585 handle.events |= POLLIN;
586 }
587
588 if (wantWrite)
589 {
590 handle.events |= POLLOUT;
591 }
592
593 int nset = (handle.fd > -1) ? ::poll (&handle, 1, timeout == 0 ? -1 : timeout) : -1;
594 if (nset != 1)
595 {
596 if (nset == -1)
597 {
598 if (handle.fd == -1)
599 {
600 errno = EBADF;
601 }
602 lastError = std::error_code (errno, std::generic_category ());
603 }
604 else
605 {
606 lastError = make_error_code (Errc::TimedOut);
607 }
608
609 return -1;
610 }
611
612 return 0;
613 }
614
617
620
622 int _handle = -1;
623
625 Protocol _protocol;
626 };
627
634 template <class Protocol>
635 constexpr bool operator< (const BasicSocket<Protocol>& a, const BasicSocket<Protocol>& b) noexcept
636 {
637 return a.handle () < b.handle ();
638 }
639
643 template <class Protocol>
644 class BasicDatagramSocket : public BasicSocket<Protocol>
645 {
646 public:
647 using Ptr = std::unique_ptr<BasicDatagramSocket<Protocol>>;
651 using Endpoint = typename Protocol::Endpoint;
652
660
665 BasicDatagramSocket (Mode mode, int ttl = 60)
666 : BasicSocket<Protocol> (mode)
667 , _ttl (ttl)
668 {
669 }
670
676
683
689 : BasicSocket<Protocol> (std::move (other))
690 , _remote (std::move (other._remote))
691 , _ttl (other._ttl)
692 {
693 other._ttl = 60;
694 }
695
702 {
703 BasicSocket<Protocol>::operator= (std::move (other));
704
705 _remote = std::move (other._remote);
706 _ttl = other._ttl;
707
708 other._ttl = 60;
709
710 return *this;
711 }
712
716 virtual ~BasicDatagramSocket () = default;
717
/**
 * @brief open the datagram socket and apply protocol-specific defaults.
 * For UDP/TCP over AF_INET6, IPV6_V6ONLY is cleared; for ICMP/ICMPv6, IP_HDRINCL is cleared
 * on AF_INET and the stored ttl is applied as both multicast and unicast TTL (results of
 * those two setOption calls are not checked).
 * @param protocol protocol to use.
 * @return 0 on success, -1 on failure.
 * NOTE(review): listing line 725, which defines `result`, is absent from this extract;
 * restore it from the upstream header before relying on this block.
 */
723 virtual int open (const Protocol& protocol = Protocol ()) noexcept override
724 {
726 if (result == -1)
727 {
728 return -1;
729 }
730
731 int off = 0;
732
733 if ((protocol.protocol () == IPPROTO_UDP) || (protocol.protocol () == IPPROTO_TCP))
734 {
735 if ((protocol.family () == AF_INET6) &&
736 (::setsockopt (this->_handle, IPPROTO_IPV6, IPV6_V6ONLY, &off, sizeof (off)) == -1))
737 {
738 lastError = std::error_code (errno, std::generic_category ());
739 this->close ();
740 return -1;
741 }
742 }
743
744 if ((protocol.protocol () == IPPROTO_ICMPV6) || (protocol.protocol () == IPPROTO_ICMP))
745 {
746 if ((protocol.family () == AF_INET) &&
747 (::setsockopt (this->_handle, IPPROTO_IP, IP_HDRINCL, &off, sizeof (off)) == -1))
748 {
749 lastError = std::error_code (errno, std::generic_category ());
750 this->close ();
751 return -1;
752 }
753
754 this->setOption (Option::MulticastTtl, this->_ttl);
755 this->setOption (Option::Ttl, this->_ttl);
756 }
757
758 return 0;
759 }
760
/**
 * @brief bind the socket to a network device through SO_BINDTODEVICE.
 * Rejected when the socket is closed, and with Errc::InUse when it is connected. ReuseAddr
 * is enabled first for AF_INET/AF_INET6 sockets (result not checked).
 * @param device device name.
 * @return 0 on success, -1 on failure.
 * NOTE(review): listing line 770 (inside the closed-state branch) is absent from this
 * extract; restore it from the upstream header before relying on this block.
 */
766 virtual int bindToDevice (const std::string& device) noexcept
767 {
768 if (this->_state == State::Closed)
769 {
771 return -1;
772 }
773
774 if (this->_state == State::Connected)
775 {
776 lastError = make_error_code (Errc::InUse);
777 return -1;
778 }
779
780 if ((this->_protocol.family () == AF_INET6) || (this->_protocol.family () == AF_INET))
781 {
782 this->setOption (Option::ReuseAddr, 1);
783 }
784
785 int result = setsockopt (this->_handle, SOL_SOCKET, SO_BINDTODEVICE, device.c_str (), device.size ());
786 if (result == -1)
787 {
788 lastError = std::error_code (errno, std::generic_category ());
789 return -1;
790 }
791
792 return 0;
793 }
794
800 virtual int connect (const Endpoint& endpoint)
801 {
802 if ((this->_state != State::Closed) && (this->_state != State::Disconnected))
803 {
804 lastError = make_error_code (Errc::InUse);
805 return -1;
806 }
807
808 if ((this->_state == State::Closed) && (this->open (endpoint.protocol ()) == -1))
809 {
810 return -1;
811 }
812
813 int result = ::connect (this->_handle, endpoint.addr (), endpoint.length ());
814
815 this->_state = State::Connecting;
816 this->_remote = endpoint;
817
818 if (result == -1)
819 {
820 lastError = std::error_code (errno, std::generic_category ());
821 if (lastError != std::errc::operation_in_progress)
822 {
823 this->close ();
824 }
825 return -1;
826 }
827
828 this->_state = State::Connected;
829
830 return 0;
831 }
832
837 virtual int disconnect ()
838 {
839 if (this->_state == State::Connected)
840 {
841 struct sockaddr_storage nullAddr;
842 ::memset (&nullAddr, 0, sizeof (nullAddr));
843
844 nullAddr.ss_family = AF_UNSPEC;
845
846 int result = ::connect (this->_handle, reinterpret_cast<struct sockaddr*> (&nullAddr),
847 sizeof (struct sockaddr_storage));
848 if (result == -1)
849 {
850 if (errno != EAFNOSUPPORT)
851 {
852 lastError = std::error_code (errno, std::generic_category ());
853 return -1;
854 }
855 }
856
857 this->_state = State::Disconnected;
858 this->_remote = {};
859 }
860
861 return 0;
862 }
863
/**
 * @brief close the datagram socket; the visible part resets the remote endpoint.
 * NOTE(review): listing line 869 is absent from this extract; restore it from the upstream
 * header before relying on this block.
 */
867 virtual void close () noexcept override
868 {
870 this->_remote = {};
871 }
872
879 virtual int read (char* data, unsigned long maxSize) noexcept override
880 {
881 return BasicSocket<Protocol>::read (data, maxSize);
882 }
883
/**
 * @brief read a datagram with ::recvfrom and optionally report its sender.
 * @param data destination buffer.
 * @param maxSize capacity of the destination buffer in bytes.
 * @param endpoint when not null, receives the sender address on success.
 * @return number of bytes received, -1 when ::recvfrom returns less than one byte
 * (a zero-byte result also moves the socket to the disconnected state).
 * NOTE(review): listing line 905 (first statement of the zero-byte branch) is absent from
 * this extract; restore it from the upstream header before relying on this block.
 */
891 virtual int readFrom (char* data, unsigned long maxSize, Endpoint* endpoint = nullptr) noexcept
892 {
893 struct sockaddr_storage sa;
894 socklen_t sa_len = sizeof (struct sockaddr_storage);
895
896 int size = ::recvfrom (this->_handle, data, maxSize, 0, reinterpret_cast<struct sockaddr*> (&sa), &sa_len);
897 if (size < 1)
898 {
899 if (size == -1)
900 {
901 lastError = std::error_code (errno, std::generic_category ());
902 }
903 else
904 {
906 this->_state = State::Disconnected;
907 }
908
909 return -1;
910 }
911
912 if (endpoint != nullptr)
913 {
914 *endpoint = Endpoint (reinterpret_cast<struct sockaddr*> (&sa), sa_len);
915 }
916
917 return size;
918 }
919
926 virtual int write (const char* data, unsigned long maxSize) noexcept override
927 {
928 return BasicSocket<Protocol>::write (data, maxSize);
929 }
930
938 virtual int writeTo (const char* data, unsigned long maxSize, const Endpoint& endpoint) noexcept
939 {
940 if ((this->_state == State::Closed) && (this->open (endpoint.protocol ()) == -1))
941 {
942 return -1;
943 }
944
945 int result = ::sendto (this->_handle, data, maxSize, 0, endpoint.addr (), endpoint.length ());
946 if (result < 0)
947 {
948 lastError = std::error_code (errno, std::generic_category ());
949 return -1;
950 }
951
952 return result;
953 }
954
/**
 * @brief set an IP-level option, picking the IPv6 or IPv4 variant from the socket family.
 * Options not handled here are forwarded to BasicSocket::setOption.
 * @param option option identifier.
 * @param value integer value passed to ::setsockopt.
 * @return 0 on success, -1 on failure or when the socket is closed.
 * NOTE(review): listing line 965 (inside the closed-state branch) is absent from this
 * extract; restore it from the upstream header before relying on this block.
 */
961 virtual int setOption (Option option, int value) noexcept override
962 {
963 if (this->_state == State::Closed)
964 {
966 return -1;
967 }
968
969 int optlevel, optname;
970
971 switch (option)
972 {
973 case Option::Ttl:
974 if (this->family () == AF_INET6)
975 {
976 optlevel = IPPROTO_IPV6;
977 optname = IPV6_UNICAST_HOPS;
978 }
979 else
980 {
981 optlevel = IPPROTO_IP;
982 optname = IP_TTL;
983 }
984 break;
985
986 case Option::MulticastLoop:
987 if (this->family () == AF_INET6)
988 {
989 optlevel = IPPROTO_IPV6;
990 optname = IPV6_MULTICAST_LOOP;
991 }
992 else
993 {
994 optlevel = IPPROTO_IP;
995 optname = IP_MULTICAST_LOOP;
996 }
997 break;
998
999 case Option::MulticastTtl:
1000 if (this->family () == AF_INET6)
1001 {
1002 optlevel = IPPROTO_IPV6;
1003 optname = IPV6_MULTICAST_HOPS;
1004 }
1005 else
1006 {
1007 optlevel = IPPROTO_IP;
1008 optname = IP_MULTICAST_TTL;
1009 }
1010 break;
1011
1012 case Option::PathMtuDiscover:
1013 if (this->family () == AF_INET6)
1014 {
1015 optlevel = IPPROTO_IPV6;
1016 optname = IPV6_MTU_DISCOVER;
1017 }
1018 else
1019 {
1020 optlevel = IPPROTO_IP;
1021 optname = IP_MTU_DISCOVER;
1022 }
1023 break;
1024
1025 case Option::RcvError:
1026 if (this->family () == AF_INET6)
1027 {
1028 optlevel = IPPROTO_IPV6;
1029 optname = IPV6_RECVERR;
1030 }
1031 else
1032 {
1033 optlevel = IPPROTO_IP;
1034 optname = IP_RECVERR;
1035 }
1036 break;
1037
1038 default:
1039 return BasicSocket<Protocol>::setOption (option, value);
1040 }
1041
1042 int result = ::setsockopt (this->_handle, optlevel, optname, &value, sizeof (value));
1043 if (result == -1)
1044 {
1045 lastError = std::error_code (errno, std::generic_category ());
1046 return -1;
1047 }
1048
1049 return 0;
1050 }
1051
1056 const Endpoint& remoteEndpoint () const noexcept
1057 {
1058 return this->_remote;
1059 }
1060
1065 virtual bool connected () noexcept
1066 {
1067 return (this->_state == State::Connected);
1068 }
1069
/**
 * @brief query the path MTU known to the socket (IPV6_MTU or IP_MTU via ::getsockopt).
 * @return the MTU value, -1 when the socket is closed, the family is neither AF_INET6 nor
 * AF_INET, or ::getsockopt fails.
 * NOTE(review): listing lines 1078 and 1095 (inside the two early-return branches) are
 * absent from this extract; restore them from the upstream header before relying on this
 * block.
 */
1074 int mtu () const
1075 {
1076 if (this->_state == State::Closed)
1077 {
1079 return -1;
1080 }
1081
1082 int result = -1, value = -1;
1083 socklen_t valueLen = sizeof (value);
1084
1085 if (this->_protocol.family () == AF_INET6)
1086 {
1087 result = ::getsockopt (this->_handle, IPPROTO_IPV6, IPV6_MTU, &value, &valueLen);
1088 }
1089 else if (this->_protocol.family () == AF_INET)
1090 {
1091 result = ::getsockopt (this->_handle, IPPROTO_IP, IP_MTU, &value, &valueLen);
1092 }
1093 else
1094 {
1096 return -1;
1097 }
1098
1099 if (result == -1)
1100 {
1101 lastError = std::error_code (errno, std::generic_category ());
1102 return -1;
1103 }
1104
1105 return value;
1106 }
1107
1112 int ttl () const noexcept
1113 {
1114 return this->_ttl;
1115 }
1116
1117 protected:
1120
1122 int _ttl = 60;
1123 };
1124
1131 template <class Protocol>
1133 {
1134 return a.handle () < b.handle ();
1135 }
1136
1140 template <class Protocol>
1142 {
1143 public:
1144 using Ptr = std::unique_ptr<BasicStreamSocket<Protocol>>;
1148 using Endpoint = typename Protocol::Endpoint;
1149
1157
1163 : BasicDatagramSocket<Protocol> (mode)
1164 {
1165 }
1166
1171 BasicStreamSocket (const BasicStreamSocket& other) = delete;
1172
1179
1185 : BasicDatagramSocket<Protocol> (std::move (other))
1186 {
1187 }
1188
1195 {
1197
1198 return *this;
1199 }
1200
1204 virtual ~BasicStreamSocket () = default;
1205
/**
 * @brief wait for a pending connection to complete.
 * Returns immediately when already connected. When connecting, waits for the socket to
 * become writable and then re-evaluates connected ().
 * @param timeout timeout forwarded to waitReadyWrite.
 * @return true when connected, false otherwise.
 * NOTE(review): listing line 1217 (inside the "not connecting" branch) is absent from this
 * extract; restore it from the upstream header before relying on this block.
 */
1211 virtual bool waitConnected (int timeout = 0)
1212 {
1213 if (this->_state != State::Connected)
1214 {
1215 if (this->_state != State::Connecting)
1216 {
1218 return false;
1219 }
1220
1221 if (!this->waitReadyWrite (timeout))
1222 {
1223 return false;
1224 }
1225
1226 return connected ();
1227 }
1228
1229 return true;
1230 }
1231
1236 virtual int disconnect () override
1237 {
1238 if (this->_state == State::Connected)
1239 {
1240 ::shutdown (this->_handle, SHUT_WR);
1241 this->_state = State::Disconnecting;
1242 }
1243
1244 if (this->_state == State::Disconnecting)
1245 {
1246 char buffer[4096];
1247 // closing before reading can make the client
1248 // not see all of our output.
1249 // we have to do a "lingering close"
1250 for (;;)
1251 {
1252 int result = this->read (buffer, sizeof (buffer));
1253 if (result <= 0)
1254 {
1255 if ((result == -1) && (lastError == Errc::TemporaryError))
1256 {
1257 return -1;
1258 }
1259
1260 break;
1261 }
1262 }
1263
1264 ::shutdown (this->_handle, SHUT_RD);
1265 this->_state = State::Disconnected;
1266 }
1267
1268 this->close ();
1269
1270 return 0;
1271 }
1272
/**
 * @brief wait for a pending disconnection to complete.
 * Returns true at once when disconnected or closed. While disconnecting, and as long as
 * lastError is a temporary error and the elapsed time (milliseconds, only tracked when
 * timeout is non-zero) does not exceed the timeout, waits for readability and retries
 * disconnect ().
 * @param timeout timeout forwarded to waitReadyRead, reduced by the elapsed time.
 * @return true when disconnected, false otherwise.
 * NOTE(review): listing line 1284 (inside the "not disconnecting" branch) is absent from
 * this extract; restore it from the upstream header before relying on this block.
 */
1278 virtual bool waitDisconnected (int timeout = 0)
1279 {
1280 if ((this->_state != State::Disconnected) && (this->_state != State::Closed))
1281 {
1282 if (this->_state != State::Disconnecting)
1283 {
1285 return false;
1286 }
1287
1288 auto start = std::chrono::steady_clock::now ();
1289 int elapsed = 0;
1290
1291 while ((lastError == Errc::TemporaryError) && (elapsed <= timeout))
1292 {
1293 if (!this->waitReadyRead (timeout - elapsed))
1294 {
1295 return false;
1296 }
1297
1298 if (this->disconnect () == 0)
1299 {
1300 return true;
1301 }
1302
1303 if (timeout)
1304 {
1305 elapsed = std::chrono::duration_cast<std::chrono::milliseconds> (
1306 std::chrono::steady_clock::now () - start)
1307 .count ();
1308 }
1309 }
1310
1311 return false;
1312 }
1313
1314 return true;
1315 }
1316
1324 int readExactly (char* data, unsigned long size, int timeout = 0)
1325 {
1326 unsigned long numRead = 0;
1327
1328 while (numRead < size)
1329 {
1330 int result = this->read (data + numRead, size - numRead);
1331 if (result == -1)
1332 {
1333 if (lastError == Errc::TemporaryError)
1334 {
1335 if (this->waitReadyRead (timeout))
1336 continue;
1337 }
1338
1339 return -1;
1340 }
1341
1342 numRead += result;
1343 }
1344
1345 return 0;
1346 }
1347
1355 int readExactly (std::string& data, unsigned long size, int timeout = 0)
1356 {
1357 data.resize (size);
1358 return readExactly (&data[0], size, timeout);
1359 }
1360
1368 int writeExactly (const char* data, unsigned long size, int timeout = 0)
1369 {
1370 unsigned long numWrite = 0;
1371
1372 while (numWrite < size)
1373 {
1374 int result = this->write (data + numWrite, size - numWrite);
1375 if (result == -1)
1376 {
1377 if (lastError == Errc::TemporaryError)
1378 {
1379 if (this->waitReadyWrite (timeout))
1380 continue;
1381 }
1382
1383 return -1;
1384 }
1385
1386 numWrite += result;
1387 }
1388
1389 return 0;
1390 }
1391
/**
 * @brief set a TCP-level option (TCP_NODELAY, TCP_KEEPIDLE, TCP_KEEPINTVL, TCP_KEEPCNT).
 * Options not handled here are forwarded to BasicDatagramSocket::setOption.
 * @param option option identifier.
 * @param value integer value passed to ::setsockopt.
 * @return 0 on success, -1 on failure or when the socket is closed.
 * NOTE(review): listing line 1402 (inside the closed-state branch) is absent from this
 * extract; restore it from the upstream header before relying on this block.
 */
1398 virtual int setOption (Option option, int value) noexcept override
1399 {
1400 if (this->_state == State::Closed)
1401 {
1403 return -1;
1404 }
1405
1406 int optlevel, optname;
1407
1408 switch (option)
1409 {
1410 case Option::NoDelay:
1411 optlevel = IPPROTO_TCP;
1412 optname = TCP_NODELAY;
1413 break;
1414
1415 case Option::KeepIdle:
1416 optlevel = IPPROTO_TCP;
1417 optname = TCP_KEEPIDLE;
1418 break;
1419
1420 case Option::KeepIntvl:
1421 optlevel = IPPROTO_TCP;
1422 optname = TCP_KEEPINTVL;
1423 break;
1424
1425 case Option::KeepCount:
1426 optlevel = IPPROTO_TCP;
1427 optname = TCP_KEEPCNT;
1428 break;
1429
1430 default:
1431 return BasicDatagramSocket<Protocol>::setOption (option, value);
1432 }
1433
1434 int result = ::setsockopt (this->_handle, optlevel, optname, &value, sizeof (value));
1435 if (result == -1)
1436 {
1437 lastError = std::error_code (errno, std::generic_category ());
1438 return -1;
1439 }
1440
1441 return 0;
1442 }
1443
1448 virtual bool connecting () const noexcept
1449 {
1450 return (this->_state == State::Connecting);
1451 }
1452
1457 virtual bool connected () noexcept override
1458 {
1459 if (this->_state == State::Connected)
1460 {
1461 return true;
1462 }
1463 else if (this->_state != State::Connecting)
1464 {
1465 return false;
1466 }
1467
1468 int optval;
1469 socklen_t optlen = sizeof (optval);
1470
1471 int result = ::getsockopt (this->_handle, SOL_SOCKET, SO_ERROR, &optval, &optlen);
1472 if ((result == -1) || (optval != 0))
1473 {
1474 return false;
1475 }
1476
1477 this->_state = State::Connected;
1478
1479 return true;
1480 }
1481
1483 friend class BasicStreamAcceptor<Protocol>;
1484 };
1485
1492 template <class Protocol>
1493 constexpr bool operator< (const BasicStreamSocket<Protocol>& a, const BasicStreamSocket<Protocol>& b) noexcept
1494 {
1495 return a.handle () < b.handle ();
1496 }
1497
1501 enum class TlsErrc
1502 {
1505 };
1506
/**
 * @brief error category for TLS errors, derived from std::error_category.
 */
class TlsCategory : public std::error_category
{
public:
    /**
     * @brief get the category name.
     * @return pointer to the category name string.
     */
    virtual const char* name () const noexcept override;

    /**
     * @brief translate an error code of this category to a human readable message.
     * @param code error code value.
     * @return the message associated with the code.
     */
    virtual std::string message (int code) const override;
};
1526
1531 const std::error_category& getTlsCategory ();
1532
1538 std::error_code make_error_code (TlsErrc code);
1539
1545 std::error_condition make_error_condition (TlsErrc code);
1546
1550 template <class Protocol>
1551 class BasicTlsSocket : public BasicStreamSocket<Protocol>
1552 {
1553 public:
1554 using Ptr = std::unique_ptr<BasicTlsSocket<Protocol>>;
1558 using Endpoint = typename Protocol::Endpoint;
1559
1565 {
1566 }
1567
1573 : BasicStreamSocket<Protocol> (mode)
1574 , _tlsContext (SSL_CTX_new (TLS_client_method ()))
1575 {
1576 // enable the OpenSSL bug workaround options.
1577 SSL_CTX_set_options (this->_tlsContext.get (), SSL_OP_ALL);
1578
1579 // disallow compression.
1580 SSL_CTX_set_options (this->_tlsContext.get (), SSL_OP_NO_COMPRESSION);
1581
1582 // disallow usage of SSLv2, SSLv3, TLSv1 and TLSv1.1 which are considered insecure.
1583 SSL_CTX_set_options (this->_tlsContext.get (),
1584 SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1);
1585
1586 // setup write mode.
1587 SSL_CTX_set_mode (this->_tlsContext.get (),
1588 SSL_MODE_ENABLE_PARTIAL_WRITE | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
1589
1590 // automatically renegotiates.
1591 SSL_CTX_set_mode (this->_tlsContext.get (), SSL_MODE_AUTO_RETRY);
1592
1593 // set session cache mode to client by default.
1594 SSL_CTX_set_session_cache_mode (this->_tlsContext.get (), SSL_SESS_CACHE_CLIENT);
1595
1596 // no verification by default.
1597 SSL_CTX_set_verify (this->_tlsContext.get (), SSL_VERIFY_NONE, nullptr);
1598
1599 // set default TLSv1.2 and below cipher suites.
1600 SSL_CTX_set_cipher_list (this->_tlsContext.get (), join::defaultCipher.c_str ());
1601
1602 // set default TLSv1.3 cipher suites.
1603 SSL_CTX_set_ciphersuites (this->_tlsContext.get (), join::defaultCipher_1_3.c_str ());
1604 }
1605
1611 : BasicTlsSocket (Mode::NonBlocking, std::move (tlsContext))
1612 {
1613 }
1614
1621 : BasicStreamSocket<Protocol> (mode)
1622 , _tlsContext (std::move (tlsContext))
1623 {
1624 if (this->_tlsContext == nullptr)
1625 {
1626 throw std::invalid_argument ("OpenSSL context is invalid");
1627 }
1628 }
1629
1634 BasicTlsSocket (const BasicTlsSocket& other) = delete;
1635
1642
1648 : BasicStreamSocket<Protocol> (std::move (other))
1649 , _tlsContext (std::move (other._tlsContext))
1650 , _tlsHandle (std::move (other._tlsHandle))
1651 , _tlsState (other._tlsState)
1652 {
1653 if (this->_tlsHandle)
1654 {
1655 SSL_set_app_data (this->_tlsHandle.get (), this);
1656 }
1657
1658 other._tlsState = TlsState::NonEncrypted;
1659 }
1660
1667 {
1668 BasicStreamSocket<Protocol>::operator= (std::move (other));
1669
1670 this->_tlsContext = std::move (other._tlsContext);
1671 this->_tlsHandle = std::move (other._tlsHandle);
1672 this->_tlsState = other._tlsState;
1673
1674 if (this->_tlsHandle)
1675 {
1676 SSL_set_app_data (this->_tlsHandle.get (), this);
1677 }
1678
1679 other._tlsState = TlsState::NonEncrypted;
1680
1681 return *this;
1682 }
1683
1687 virtual ~BasicTlsSocket () = default;
1688
1694 virtual int connectEncrypted (const Endpoint& endpoint)
1695 {
1696 if (this->connect (endpoint) == -1)
1697 {
1698 return -1;
1699 }
1700
1701 if (this->startEncryption () == -1)
1702 {
1703 if (lastError != Errc::TemporaryError)
1704 {
1705 this->close ();
1706 }
1707 return -1;
1708 }
1709
1710 return 0;
1711 }
1712
1718 {
1719 if (this->encrypted () == false)
1720 {
1721 this->_tlsHandle.reset (SSL_new (this->_tlsContext.get ()));
1722 if (this->_tlsHandle == nullptr)
1723 {
1724 lastError = make_error_code (Errc::OutOfMemory);
1725 return -1;
1726 }
1727
1728 if (SSL_set_fd (this->_tlsHandle.get (), this->_handle) != 1)
1729 {
1730 lastError = make_error_code (Errc::InvalidParam);
1731 this->_tlsHandle.reset ();
1732 return -1;
1733 }
1734
1735 if (SSL_is_server (this->_tlsHandle.get ()) == 0)
1736 {
1737 if (!this->_remote.hostname ().empty () &&
1738 (SSL_set_tlsext_host_name (this->_tlsHandle.get (), this->_remote.hostname ().c_str ()) != 1))
1739 {
1740 lastError = make_error_code (Errc::InvalidParam);
1741 this->_tlsHandle.reset ();
1742 return -1;
1743 }
1744
1745 SSL_set_connect_state (this->_tlsHandle.get ());
1746 }
1747 else
1748 {
1749 SSL_set_accept_state (this->_tlsHandle.get ());
1750 }
1751
1752 SSL_set_app_data (this->_tlsHandle.get (), this);
1753
1754#ifdef DEBUG
1755 SSL_set_info_callback (this->_tlsHandle.get (), infoWrapper);
1756#endif
1757
1758 return startHandshake ();
1759 }
1760
1761 return 0;
1762 }
1763
1769 virtual bool waitEncrypted (int timeout = 0)
1770 {
1771 if (this->encrypted () == false)
1772 {
1773 if (this->_state == State::Connecting)
1774 {
1775 if (!this->waitConnected (timeout))
1776 {
1777 return false;
1778 }
1779
1780 if (this->startEncryption () == 0)
1781 {
1782 return true;
1783 }
1784 }
1785
1786 while ((lastError == Errc::TemporaryError) &&
1787 (SSL_want_read (this->_tlsHandle.get ()) || SSL_want_write (this->_tlsHandle.get ())))
1788 {
1789 if (this->wait (SSL_want_read (this->_tlsHandle.get ()), SSL_want_write (this->_tlsHandle.get ()),
1790 timeout) == -1)
1791 {
1792 return false;
1793 }
1794
1795 if (this->startHandshake () == 0)
1796 {
1797 return true;
1798 }
1799 }
1800
1801 return false;
1802 }
1803
1804 return true;
1805 }
1806
/**
 * @brief shut the TLS session down.
 * When encrypted and the close_notify alert was not sent yet, SSL_shutdown is called; a
 * negative result is mapped through SSL_get_error and returns -1.
 * @return -1 when SSL_shutdown fails; the remaining return path is not visible here.
 * NOTE(review): listing lines 1828, 1837-1838, 1851, 1861 and 1872 are absent from this
 * extract (including the statement that ends the function); restore them from the upstream
 * header before relying on this block.
 */
1811 virtual int disconnect () override
1812 {
1813 if (this->encrypted ())
1814 {
1815 // check if the close_notify alert was already sent.
1816 if ((SSL_get_shutdown (this->_tlsHandle.get ()) & SSL_SENT_SHUTDOWN) == false)
1817 {
1818 // send the close_notify alert to the peer.
1819 int result = SSL_shutdown (this->_tlsHandle.get ());
1820 if (result < 0)
1821 {
1822 // shutdown was not successful.
1823 switch (SSL_get_error (this->_tlsHandle.get (), result))
1824 {
1825 case SSL_ERROR_WANT_READ:
1826 case SSL_ERROR_WANT_WRITE:
1827 // SSL_shutdown want read or want write.
1829 break;
1830 case SSL_ERROR_SYSCALL:
1831 // an error occurred at the socket level.
1832 switch (errno)
1833 {
1834 case 0:
1835 case ECONNRESET:
1836 case EPIPE:
1839 this->_state = State::Disconnected;
1840 break;
1841 default:
1842 lastError = std::error_code (errno, std::generic_category ());
1843 break;
1844 }
1845 break;
1846 default:
1847 // SSL protocol error.
1848#ifdef DEBUG
1849 std::cout << ERR_reason_error_string (ERR_get_error ()) << std::endl;
1850#endif
1852 break;
1853 }
1854
1855 return -1;
1856 }
1857 else if (result == 1)
1858 {
1859 // shutdown was successfully completed.
1860 // close_notify alert was sent and the peer's close_notify alert was received.
1862 }
1863 else
1864 {
1865 // shutdown is not yet finished.
1866 // the close_notify was sent but the peer did not send it back yet.
1867 // SSL_read must be called to do a bidirectional shutdown.
1868 }
1869 }
1870 }
1871
1873 }
1874
/**
 * @brief close the TLS socket; the visible part releases the TLS handle.
 * NOTE(review): listing lines 1880-1881 are absent from this extract; restore them from the
 * upstream header before relying on this block.
 */
1878 virtual void close () noexcept override
1879 {
1882 this->_tlsHandle.reset ();
1883 }
1884
/**
 * @brief block until the socket is ready for reading.
 * When encrypted and the TLS handle wants to read or write, waits for exactly what the TLS
 * layer asks for.
 * @param timeout timeout forwarded to wait.
 * @return true when wait () succeeds on the encrypted path; the other path is not visible.
 * NOTE(review): listing line 1899 (the fallback return) is absent from this extract;
 * restore it from the upstream header before relying on this block.
 */
1890 virtual bool waitReadyRead (int timeout = 0) const noexcept override
1891 {
1892 if (this->encrypted () &&
1893 (SSL_want_read (this->_tlsHandle.get ()) || SSL_want_write (this->_tlsHandle.get ())))
1894 {
1895 return (this->wait (SSL_want_read (this->_tlsHandle.get ()), SSL_want_write (this->_tlsHandle.get ()),
1896 timeout) == 0);
1897 }
1898
1900 }
1901
/**
 * @brief get the number of readable bytes.
 * @return the result of SSL_pending when encrypted; the other path is not visible.
 * NOTE(review): listing line 1913 (the fallback return) is absent from this extract;
 * restore it from the upstream header before relying on this block.
 */
1906 virtual int canRead () const noexcept override
1907 {
1908 if (this->encrypted ())
1909 {
1910 return SSL_pending (this->_tlsHandle.get ());
1911 }
1912
1914 }
1915
/**
 * @brief read data from the socket, through SSL_read when encrypted.
 * A result below one is mapped through SSL_get_error and returns -1; when not encrypted
 * the call is forwarded to BasicStreamSocket::read.
 * @param data destination buffer.
 * @param maxSize capacity of the destination buffer (narrowed to int for SSL_read).
 * @return number of bytes read, -1 on failure.
 * NOTE(review): listing lines 1936, 1941, 1944, 1954-1955 and 1968 are absent from this
 * extract; restore them from the upstream header before relying on this block.
 */
1922 virtual int read (char* data, unsigned long maxSize) noexcept override
1923 {
1924 if (this->encrypted ())
1925 {
1926 // read data.
1927 int result = SSL_read (this->_tlsHandle.get (), data, int (maxSize));
1928 if (result < 1)
1929 {
1930 switch (SSL_get_error (this->_tlsHandle.get (), result))
1931 {
1932 case SSL_ERROR_WANT_READ:
1933 case SSL_ERROR_WANT_WRITE:
1934 case SSL_ERROR_WANT_X509_LOOKUP:
1935 // SSL_read want read, want write or want lookup.
1937 break;
1938 case SSL_ERROR_ZERO_RETURN:
1939 // a close notify alert was received.
1940 // we have to answer by sending a close notify alert too.
1942 if (SSL_get_shutdown (this->_tlsHandle.get ()) & SSL_SENT_SHUTDOWN)
1943 {
1945 }
1946 break;
1947 case SSL_ERROR_SYSCALL:
1948 // an error occurred at the socket level.
1949 switch (errno)
1950 {
1951 case 0:
1952 case ECONNRESET:
1953 case EPIPE:
1956 this->_state = State::Disconnected;
1957 break;
1958 default:
1959 lastError = std::error_code (errno, std::generic_category ());
1960 break;
1961 }
1962 break;
1963 default:
1964 // SSL protocol error.
1965#ifdef DEBUG
1966 std::cout << ERR_reason_error_string (ERR_get_error ()) << std::endl;
1967#endif
1969 break;
1970 }
1971
1972 return -1;
1973 }
1974
1975 return result;
1976 }
1977
1978 return BasicStreamSocket<Protocol>::read (data, maxSize);
1979 }
1980
/**
 * @brief block until the socket is ready for writing.
 * When encrypted and the TLS handle wants to read or write, waits for exactly what the TLS
 * layer asks for.
 * @param timeout timeout forwarded to wait.
 * @return true when wait () succeeds on the encrypted path; the other path is not visible.
 * NOTE(review): listing line 1995 (the fallback return) is absent from this extract;
 * restore it from the upstream header before relying on this block.
 */
1986 virtual bool waitReadyWrite (int timeout = 0) const noexcept override
1987 {
1988 if (this->encrypted () &&
1989 (SSL_want_read (this->_tlsHandle.get ()) || SSL_want_write (this->_tlsHandle.get ())))
1990 {
1991 return (this->wait (SSL_want_read (this->_tlsHandle.get ()), SSL_want_write (this->_tlsHandle.get ()),
1992 timeout) == 0);
1993 }
1994
1996 }
1997
/**
 * @brief write data to the socket, through SSL_write when encrypted.
 * A result below one is mapped through SSL_get_error and returns -1; when not encrypted
 * the call is forwarded to BasicStreamSocket::write.
 * @param data buffer holding the bytes to send.
 * @param maxSize number of bytes available (narrowed to int for SSL_write).
 * @return number of bytes written, -1 on failure.
 * NOTE(review): listing lines 2018, 2023, 2026, 2036-2037 and 2050 are absent from this
 * extract; restore them from the upstream header before relying on this block.
 */
2004 virtual int write (const char* data, unsigned long maxSize) noexcept override
2005 {
2006 if (this->encrypted ())
2007 {
2008 // write data.
2009 int result = SSL_write (this->_tlsHandle.get (), data, int (maxSize));
2010 if (result < 1)
2011 {
2012 switch (SSL_get_error (this->_tlsHandle.get (), result))
2013 {
2014 case SSL_ERROR_WANT_READ:
2015 case SSL_ERROR_WANT_WRITE:
2016 case SSL_ERROR_WANT_X509_LOOKUP:
2017 // SSL_write want read, want write or want lookup.
2019 break;
2020 case SSL_ERROR_ZERO_RETURN:
2021 // a close notify alert was received.
2022 // we have to answer by sending a close notify alert too.
2024 if (SSL_get_shutdown (this->_tlsHandle.get ()) & SSL_SENT_SHUTDOWN)
2025 {
2027 }
2028 break;
2029 case SSL_ERROR_SYSCALL:
2030 // an error occurred at the socket level.
2031 switch (errno)
2032 {
2033 case 0:
2034 case ECONNRESET:
2035 case EPIPE:
2038 this->_state = State::Disconnected;
2039 break;
2040 default:
2041 lastError = std::error_code (errno, std::generic_category ());
2042 break;
2043 }
2044 break;
2045 default:
2046 // SSL protocol error.
2047#ifdef DEBUG
2048 std::cout << ERR_reason_error_string (ERR_get_error ()) << std::endl;
2049#endif
2051 break;
2052 }
2053
2054 return -1;
2055 }
2056
2057 return result;
2058 }
2059
2060 return BasicStreamSocket<Protocol>::write (data, maxSize);
2061 }
2062
2067 virtual bool encrypted () const noexcept override
2068 {
2069 return (this->_tlsState == TlsState::Encrypted);
2070 }
2071
2078 int setCertificate (const std::string& cert, const std::string& key = "")
2079 {
2080 if (((this->_tlsHandle)
2081 ? SSL_use_certificate_file (this->_tlsHandle.get (), cert.c_str (), SSL_FILETYPE_PEM)
2082 : SSL_CTX_use_certificate_file (this->_tlsContext.get (), cert.c_str (), SSL_FILETYPE_PEM)) == 0)
2083 {
2084 lastError = make_error_code (Errc::InvalidParam);
2085 return -1;
2086 }
2087
2088 if (key.size ())
2089 {
2090 if (((this->_tlsHandle)
2091 ? SSL_use_PrivateKey_file (this->_tlsHandle.get (), key.c_str (), SSL_FILETYPE_PEM)
2092 : SSL_CTX_use_PrivateKey_file (this->_tlsContext.get (), key.c_str (), SSL_FILETYPE_PEM)) == 0)
2093 {
2094 lastError = make_error_code (Errc::InvalidParam);
2095 return -1;
2096 }
2097 }
2098
2099 if (((this->_tlsHandle) ? SSL_check_private_key (this->_tlsHandle.get ())
2100 : SSL_CTX_check_private_key (this->_tlsContext.get ())) == 0)
2101 {
2102 lastError = make_error_code (Errc::InvalidParam);
2103 return -1;
2104 }
2105
2106 return 0;
2107 }
2108
2114 int setCaPath (const std::string& caPath)
2115 {
2116 struct stat st;
2117 if (stat (caPath.c_str (), &st) != 0 || !S_ISDIR (st.st_mode) ||
2118 SSL_CTX_load_verify_locations (this->_tlsContext.get (), nullptr, caPath.c_str ()) == 0)
2119 {
2120 lastError = make_error_code (Errc::InvalidParam);
2121 return -1;
2122 }
2123
2124 return 0;
2125 }
2126
2132 int setCaFile (const std::string& caFile)
2133 {
2134 struct stat st;
2135 if (stat (caFile.c_str (), &st) != 0 || !S_ISREG (st.st_mode) ||
2136 SSL_CTX_load_verify_locations (this->_tlsContext.get (), caFile.c_str (), nullptr) == 0)
2137 {
2138 lastError = make_error_code (Errc::InvalidParam);
2139 return -1;
2140 }
2141
2142 return 0;
2143 }
2144
2150 void setVerify (bool verify, int depth = -1) noexcept
2151 {
2152 if (verify == true)
2153 {
2154 SSL_CTX_set_verify (this->_tlsContext.get (), SSL_VERIFY_PEER, verifyWrapper);
2155 SSL_CTX_set_verify_depth (this->_tlsContext.get (), depth);
2156 }
2157 else
2158 {
2159 SSL_CTX_set_verify (this->_tlsContext.get (), SSL_VERIFY_NONE, nullptr);
2160 }
2161 }
2162
2168 int setCipher (const std::string& cipher)
2169 {
2170 if (((this->_tlsHandle) ? SSL_set_cipher_list (this->_tlsHandle.get (), cipher.c_str ())
2171 : SSL_CTX_set_cipher_list (this->_tlsContext.get (), cipher.c_str ())) == 0)
2172 {
2173 lastError = make_error_code (Errc::InvalidParam);
2174 return -1;
2175 }
2176
2177 return 0;
2178 }
2179
2185 int setCipher_1_3 (const std::string& cipher)
2186 {
2187 if (((this->_tlsHandle) ? SSL_set_ciphersuites (this->_tlsHandle.get (), cipher.c_str ())
2188 : SSL_CTX_set_ciphersuites (this->_tlsContext.get (), cipher.c_str ())) == 0)
2189 {
2190 lastError = make_error_code (Errc::InvalidParam);
2191 return -1;
2192 }
2193
2194 return 0;
2195 }
2196
2202 int setAlpnProtocols (const std::vector<std::string>& protocols)
2203 {
2204 std::vector<uint8_t> wire;
2205 wire.reserve (256);
2206
2207 for (auto const& proto : protocols)
2208 {
2209 wire.push_back (static_cast<uint8_t> (proto.size ()));
2210 wire.insert (wire.end (), proto.begin (), proto.end ());
2211 }
2212
2213 if (SSL_CTX_set_alpn_protos (this->_tlsContext.get (), wire.data (),
2214 static_cast<unsigned int> (wire.size ())) != 0)
2215 {
2216 lastError = make_error_code (Errc::InvalidParam);
2217 return -1;
2218 }
2219
2220 return 0;
2221 }
2222
2223 protected:
2232
// Body of startHandshake () ("Start SSL handshake").
// NOTE(review): the signature (listing line 2237) and listing lines 2249, 2254, 2263, 2276 and 2283
// are absent from this copy, so what several case branches assign, and the statement executed just
// before the final "return 0", are not visible here; restore them from the upstream header.
// Visible behaviour: SSL_do_handshake is called on the TLS handle; a result below 1 returns -1 after
// inspecting SSL_get_error (errno 0, ECONNRESET or EPIPE mark the socket Disconnected, any other
// errno is stored in lastError using the generic category); otherwise 0 is returned.
2238 {
2239 // start the SSL handshake.
2240 int result = SSL_do_handshake (this->_tlsHandle.get ());
2241 if (result < 1)
2242 {
2243 switch (SSL_get_error (this->_tlsHandle.get (), result))
2244 {
2245 case SSL_ERROR_WANT_READ:
2246 case SSL_ERROR_WANT_WRITE:
2247 case SSL_ERROR_WANT_X509_LOOKUP:
2248 // SSL_do_handshake want read or want write.
2250 break;
2251 case SSL_ERROR_ZERO_RETURN:
2252 // a close notify alert was received.
2253 // we have to answer by sending a close notify alert too.
2255 break;
2256 case SSL_ERROR_SYSCALL:
2257 // an error occurred at the socket level.
2258 switch (errno)
2259 {
2260 case 0:
2261 case ECONNRESET:
2262 case EPIPE:
2264 this->_state = State::Disconnected;
2265 break;
2266 default:
2267 lastError = std::error_code (errno, std::generic_category ());
2268 break;
2269 }
2270 break;
2271 default:
2272 // SSL protocol error.
2273#ifdef DEBUG
2274 std::cout << ERR_reason_error_string (ERR_get_error ()) << std::endl;
2275#endif
2277 break;
2278 }
2279
2280 return -1;
2281 }
2282
2284
2285 return 0;
2286 }
2287
2294 static void infoWrapper (const SSL* ssl, int where, int ret)
2295 {
2296 assert (ssl);
2297 static_cast<BasicTlsSocket<Protocol>*> (SSL_get_app_data (ssl))->infoCallback (where, ret);
2298 }
2299
2305 void infoCallback (int where, int ret) const
2306 {
2307 if (where & SSL_CB_ALERT)
2308 {
2309 std::cout << "SSL/TLS Alert ";
2310 (where & SSL_CB_READ) ? std::cout << "[read] " : std::cout << "[write] ";
2311 std::cout << SSL_alert_type_string_long (ret) << ":";
2312 std::cout << SSL_alert_desc_string_long (ret);
2313 std::cout << std::endl;
2314 }
2315 else if (where & SSL_CB_LOOP)
2316 {
2317 std::cout << "SSL/TLS State ";
2318 (SSL_in_connect_init (this->_tlsHandle.get ())) ? std::cout << "[connect] "
2319 : (SSL_in_accept_init (this->_tlsHandle.get ())) ? std::cout << "[accept] "
2320 : std::cout << "[undefined] ";
2321 std::cout << SSL_state_string_long (this->_tlsHandle.get ());
2322 std::cout << std::endl;
2323 }
2324 else if (where & SSL_CB_HANDSHAKE_START)
2325 {
2326 std::cout << "SSL/TLS Handshake [Start] " << SSL_state_string_long (this->_tlsHandle.get ())
2327 << std::endl;
2328 }
2329 else if (where & SSL_CB_HANDSHAKE_DONE)
2330 {
2331 std::cout << "SSL/TLS Handshake [Done] " << SSL_state_string_long (this->_tlsHandle.get ())
2332 << std::endl;
2333 std::cout << SSL_CTX_sess_number (this->_tlsContext.get ()) << " items in the session cache"
2334 << std::endl;
2335 std::cout << SSL_CTX_sess_connect (this->_tlsContext.get ()) << " client connects" << std::endl;
2336 std::cout << SSL_CTX_sess_connect_good (this->_tlsContext.get ()) << " client connects that finished"
2337 << std::endl;
2338 std::cout << SSL_CTX_sess_connect_renegotiate (this->_tlsContext.get ())
2339 << " client renegotiations requested" << std::endl;
2340 std::cout << SSL_CTX_sess_accept (this->_tlsContext.get ()) << " server connects" << std::endl;
2341 std::cout << SSL_CTX_sess_accept_good (this->_tlsContext.get ()) << " server connects that finished"
2342 << std::endl;
2343 std::cout << SSL_CTX_sess_accept_renegotiate (this->_tlsContext.get ())
2344 << " server renegotiations requested" << std::endl;
2345 std::cout << SSL_CTX_sess_hits (this->_tlsContext.get ()) << " session cache hits" << std::endl;
2346 std::cout << SSL_CTX_sess_cb_hits (this->_tlsContext.get ()) << " external session cache hits"
2347 << std::endl;
2348 std::cout << SSL_CTX_sess_misses (this->_tlsContext.get ()) << " session cache misses" << std::endl;
2349 std::cout << SSL_CTX_sess_timeouts (this->_tlsContext.get ()) << " session cache timeouts" << std::endl;
2350 std::cout << "negotiated " << SSL_get_cipher (this->_tlsHandle.get ()) << " cipher suite" << std::endl;
2351 }
2352 }
2353
2360 static int verifyWrapper (int preverified, X509_STORE_CTX* context)
2361 {
2362 SSL* ssl = static_cast<SSL*> (X509_STORE_CTX_get_ex_data (context, SSL_get_ex_data_X509_STORE_CTX_idx ()));
2363
2364 assert (ssl);
2365 return static_cast<BasicTlsSocket<Protocol>*> (SSL_get_app_data (ssl))
2366 ->verifyCallback (preverified, context);
2367 }
2368
2375 int verifyCallback (int preverified, X509_STORE_CTX* context) const
2376 {
2377 int maxDepth = SSL_get_verify_depth (this->_tlsHandle.get ());
2378 int dpth = X509_STORE_CTX_get_error_depth (context);
2379
2380#ifdef DEBUG
2381 std::cout << "verification started at depth=" << dpth << std::endl;
2382#endif
2383
2384 // catch a too long certificate chain.
2385 if ((maxDepth >= 0) && (dpth > maxDepth))
2386 {
2387 preverified = 0;
2388 X509_STORE_CTX_set_error (context, X509_V_ERR_CERT_CHAIN_TOO_LONG);
2389 }
2390
2391 if (!preverified)
2392 {
2393#ifdef DEBUG
2394 std::cout << "verification failed at depth=" << dpth << " - "
2395 << X509_verify_cert_error_string (X509_STORE_CTX_get_error (context)) << std::endl;
2396#endif
2397 return 0;
2398 }
2399
2400 // check the certificate host name.
2401 if (!verifyCert (context))
2402 {
2403#ifdef DEBUG
2404 std::cout << "rejected by CERT at depth=" << dpth << std::endl;
2405#endif
2406 return 0;
2407 }
2408
2409 // check the revocation list.
2410 /*if (!verifyCrl (context))
2411 {
2412 #ifdef DEBUG
2413 std::cout << "rejected by CRL at depth=" << dpth << std::endl;
2414 #endif
2415 return 0;
2416 }*/
2417
2418 // check ocsp.
2419 /*if (!verifyOcsp (context))
2420 {
2421 #ifdef DEBUG
2422 std::cout << "rejected by OCSP at depth=" << dpth << std::endl;
2423 #endif
2424 return 0;
2425 }*/
2426
2427#ifdef DEBUG
2428 std::cout << "certificate accepted at depth=" << dpth << std::endl;
2429#endif
2430
2431 return 1;
2432 }
2433
2439 int verifyCert (X509_STORE_CTX* context) const
2440 {
2441 int depth = X509_STORE_CTX_get_error_depth (context);
2442 X509* cert = X509_STORE_CTX_get_current_cert (context);
2443
2444 char buf[256];
2445 X509_NAME_oneline (X509_get_subject_name (cert), buf, sizeof (buf));
2446#ifdef DEBUG
2447 std::cout << "subject=" << buf << std::endl;
2448#endif
2449
2450 // check the certificate host name
2451 if (depth == 0)
2452 {
2453 // confirm a match between the hostname and the hostnames listed in the certificate.
2454 if (!checkHostName (cert))
2455 {
2456#ifdef DEBUG
2457 std::cout << "no match for hostname in the certificate" << std::endl;
2458#endif
2459 return 0;
2460 }
2461 }
2462
2463 return 1;
2464 }
2465
2471 bool checkHostName (X509* certificate) const
2472 {
2473 bool match = false;
2474
2475 // get alternative names.
2476 join::StackOfGeneralNamePtr altnames (reinterpret_cast<STACK_OF (GENERAL_NAME)*> (
2477 X509_get_ext_d2i (certificate, NID_subject_alt_name, 0, 0)));
2478 if (altnames)
2479 {
2480 for (int i = 0; (i < sk_GENERAL_NAME_num (altnames.get ())) && !match; ++i)
2481 {
2482 // get a handle to alternative name.
2483 GENERAL_NAME* current_name = sk_GENERAL_NAME_value (altnames.get (), i);
2484
2485 if (current_name->type == GEN_DNS)
2486 {
2487 // get data and length.
2488 const char* host = reinterpret_cast<const char*> (ASN1_STRING_get0_data (current_name->d.ia5));
2489 size_t len = size_t (ASN1_STRING_length (current_name->d.ia5));
2490 std::string pattern (host, host + len), serverName (this->_remote.hostname ());
2491
2492 // strip off trailing dots.
2493 if (pattern.back () == '.')
2494 {
2495 pattern.pop_back ();
2496 }
2497
2498 if (serverName.back () == '.')
2499 {
2500 serverName.pop_back ();
2501 }
2502
2503 // compare to pattern.
2504 if (fnmatch (pattern.c_str (), serverName.c_str (), 0) == 0)
2505 {
2506 // an alternative name matched the server hostname.
2507 match = true;
2508 }
2509 }
2510 }
2511 }
2512
2513 return match;
2514 }
2515
2521 /*int verifyCrl ([[maybe_unused]]X509_STORE_CTX *context) const
2522 {
2523 return 1;
2524 }*/
2525
2531 /*int verifyOcsp ([[maybe_unused]]X509_STORE_CTX *context) const
2532 {
2533 return 1;
2534 }*/
2535
2538
2541
2544
2546 friend class BasicTlsAcceptor<Protocol>;
2547 };
2548
2555 template <class Protocol>
2556 constexpr bool operator< (const BasicTlsSocket<Protocol>& a, const BasicTlsSocket<Protocol>& b) noexcept
2557 {
2558 return a.handle () < b.handle ();
2559 }
2560}
2561
2562namespace std
2563{
2565 template <>
2566 struct is_error_condition_enum<join::TlsErrc> : public true_type
2567 {
2568 };
2569}
2570
2571#endif
basic datagram socket class.
Definition socket.hpp:645
virtual ~BasicDatagramSocket()=default
Destroy the instance.
BasicDatagramSocket(const BasicDatagramSocket &other)=delete
Copy constructor.
typename BasicSocket< Protocol >::Option Option
Definition socket.hpp:649
BasicDatagramSocket(Mode mode, int ttl=60)
Create instance specifying the mode.
Definition socket.hpp:665
const Endpoint & remoteEndpoint() const noexcept
determine the remote endpoint associated with this socket.
Definition socket.hpp:1056
virtual int connect(const Endpoint &endpoint)
make a connection to the given endpoint.
Definition socket.hpp:800
virtual int bindToDevice(const std::string &device) noexcept
assigns the specified device to the socket.
Definition socket.hpp:766
std::unique_ptr< BasicDatagramSocket< Protocol > > Ptr
Definition socket.hpp:647
virtual bool connected() noexcept
check if the socket is connected.
Definition socket.hpp:1065
virtual int setOption(Option option, int value) noexcept override
set the given option to the given value.
Definition socket.hpp:961
Endpoint _remote
remote endpoint.
Definition socket.hpp:1119
virtual int read(char *data, unsigned long maxSize) noexcept override
read data.
Definition socket.hpp:879
int _ttl
packet time to live.
Definition socket.hpp:1122
virtual int write(const char *data, unsigned long maxSize) noexcept override
write data.
Definition socket.hpp:926
virtual void close() noexcept override
close the socket handle.
Definition socket.hpp:867
typename Protocol::Endpoint Endpoint
Definition socket.hpp:651
BasicDatagramSocket(int ttl=60)
Default constructor.
Definition socket.hpp:656
typename BasicSocket< Protocol >::State State
Definition socket.hpp:650
int mtu() const
get socket mtu.
Definition socket.hpp:1074
BasicDatagramSocket & operator=(const BasicDatagramSocket &other)=delete
Copy assignment operator.
virtual int readFrom(char *data, unsigned long maxSize, Endpoint *endpoint=nullptr) noexcept
read data on the socket.
Definition socket.hpp:891
virtual int writeTo(const char *data, unsigned long maxSize, const Endpoint &endpoint) noexcept
write data on the socket.
Definition socket.hpp:938
virtual int disconnect()
shutdown the connection.
Definition socket.hpp:837
typename BasicSocket< Protocol >::Mode Mode
Definition socket.hpp:648
virtual int open(const Protocol &protocol=Protocol()) noexcept override
open socket using the given protocol.
Definition socket.hpp:723
BasicDatagramSocket(BasicDatagramSocket &&other)
Move constructor.
Definition socket.hpp:688
int ttl() const noexcept
returns the Time-To-Live value.
Definition socket.hpp:1112
basic socket class.
Definition socket.hpp:60
bool opened() const noexcept
check if the socket is opened.
Definition socket.hpp:489
BasicSocket & operator=(const BasicSocket &other)=delete
copy assignment operator.
virtual int open(const Protocol &protocol=Protocol()) noexcept
open socket using the given protocol.
Definition socket.hpp:194
static uint16_t checksum(const uint16_t *data, size_t len, uint16_t current=0)
get standard 1s complement checksum.
Definition socket.hpp:546
void setMode(Mode mode) noexcept
set the socket to the non-blocking or blocking mode.
Definition socket.hpp:380
State
socket states.
Definition socket.hpp:102
@ Disconnected
Definition socket.hpp:106
@ Connecting
Definition socket.hpp:103
@ Disconnecting
Definition socket.hpp:105
@ Closed
Definition socket.hpp:107
@ Connected
Definition socket.hpp:104
Protocol _protocol
protocol.
Definition socket.hpp:625
virtual int setOption(Option option, int value) noexcept
set the given option to the given value.
Definition socket.hpp:407
Mode _mode
socket mode.
Definition socket.hpp:619
virtual void close() noexcept
close the socket.
Definition socket.hpp:223
int family() const noexcept
get socket address family.
Definition socket.hpp:507
Mode
socket modes.
Definition socket.hpp:69
@ Blocking
Definition socket.hpp:70
@ NonBlocking
Definition socket.hpp:71
virtual bool waitReadyRead(int timeout=0) const noexcept
block until new data is available for reading.
Definition socket.hpp:294
std::unique_ptr< BasicSocket< Protocol > > Ptr
Definition socket.hpp:62
virtual ~BasicSocket()
destroy the socket instance.
Definition socket.hpp:181
virtual int bind(const Endpoint &endpoint) noexcept
assigns the specified endpoint to the socket.
Definition socket.hpp:238
BasicSocket()
default constructor.
Definition socket.hpp:113
int type() const noexcept
get the protocol communication semantic.
Definition socket.hpp:516
State _state
socket state.
Definition socket.hpp:616
typename Protocol::Endpoint Endpoint
Definition socket.hpp:63
virtual bool encrypted() const noexcept
check if the socket is secure.
Definition socket.hpp:498
Option
socket options.
Definition socket.hpp:78
@ MulticastTtl
Definition socket.hpp:92
@ SndBuffer
Definition socket.hpp:84
@ Ttl
Definition socket.hpp:90
@ Broadcast
Definition socket.hpp:89
@ RcvError
Definition socket.hpp:94
@ ReusePort
Definition socket.hpp:88
@ MulticastLoop
Definition socket.hpp:91
@ KeepAlive
Definition socket.hpp:80
@ ReuseAddr
Definition socket.hpp:87
@ KeepCount
Definition socket.hpp:83
@ KeepIntvl
Definition socket.hpp:82
@ AuxData
Definition socket.hpp:95
@ PathMtuDiscover
Definition socket.hpp:93
@ NoDelay
Definition socket.hpp:79
@ TimeStamp
Definition socket.hpp:86
@ KeepIdle
Definition socket.hpp:81
@ RcvBuffer
Definition socket.hpp:85
Endpoint localEndpoint() const noexcept
determine the local endpoint associated with this socket.
Definition socket.hpp:472
virtual bool waitReadyWrite(int timeout=0) const noexcept
block until at least one byte can be written.
Definition socket.hpp:341
BasicSocket(BasicSocket &&other)
move constructor.
Definition socket.hpp:144
int handle() const noexcept
get socket native handle.
Definition socket.hpp:534
virtual int read(char *data, unsigned long maxSize) noexcept
read data.
Definition socket.hpp:305
virtual int write(const char *data, unsigned long maxSize) noexcept
write data.
Definition socket.hpp:352
BasicSocket(Mode mode)
create socket instance specifying the mode.
Definition socket.hpp:122
int wait(bool wantRead, bool wantWrite, int timeout) const noexcept
wait for the socket handle to become ready.
Definition socket.hpp:579
int protocol() const noexcept
get socket protocol.
Definition socket.hpp:525
int _handle
socket handle.
Definition socket.hpp:622
virtual int canRead() const noexcept
get the number of readable bytes.
Definition socket.hpp:275
BasicSocket(const BasicSocket &other)=delete
copy constructor.
basic stream acceptor class.
Definition protocol.hpp:52
basic stream socket class.
Definition socket.hpp:1142
virtual ~BasicStreamSocket()=default
destroy the instance.
virtual int setOption(Option option, int value) noexcept override
set the given option to the given value.
Definition socket.hpp:1398
virtual bool connecting() const noexcept
check if the socket is connecting.
Definition socket.hpp:1448
BasicStreamSocket(const BasicStreamSocket &other)=delete
copy constructor.
BasicStreamSocket(BasicStreamSocket &&other)
move constructor.
Definition socket.hpp:1184
virtual bool waitConnected(int timeout=0)
block until connected.
Definition socket.hpp:1211
int readExactly(std::string &data, unsigned long size, int timeout=0)
read data until size is reached or an error occurred.
Definition socket.hpp:1355
typename BasicDatagramSocket< Protocol >::Mode Mode
Definition socket.hpp:1145
int writeExactly(const char *data, unsigned long size, int timeout=0)
write data until size is reached or an error occurred.
Definition socket.hpp:1368
BasicStreamSocket(Mode mode)
create instance specifying the mode.
Definition socket.hpp:1162
typename Protocol::Endpoint Endpoint
Definition socket.hpp:1148
std::unique_ptr< BasicStreamSocket< Protocol > > Ptr
Definition socket.hpp:1144
BasicStreamSocket & operator=(const BasicStreamSocket &other)=delete
copy assignment operator.
typename BasicDatagramSocket< Protocol >::Option Option
Definition socket.hpp:1146
virtual bool connected() noexcept override
check if the socket is connected.
Definition socket.hpp:1457
virtual bool waitDisconnected(int timeout=0)
wait until the connection has been shut down.
Definition socket.hpp:1278
virtual int disconnect() override
shutdown the connection.
Definition socket.hpp:1236
BasicStreamSocket()
default constructor.
Definition socket.hpp:1153
int readExactly(char *data, unsigned long size, int timeout=0)
read data until size is reached or an error occurred.
Definition socket.hpp:1324
typename BasicDatagramSocket< Protocol >::State State
Definition socket.hpp:1147
basic TLS acceptor class.
Definition protocol.hpp:54
basic TLS socket class.
Definition socket.hpp:1552
BasicTlsSocket()
default constructor.
Definition socket.hpp:1563
BasicTlsSocket & operator=(const BasicTlsSocket &other)=delete
copy assignment operator.
int setCaFile(const std::string &caFile)
set the location of the trusted CA certificate file.
Definition socket.hpp:2132
virtual int canRead() const noexcept override
get the number of readable bytes.
Definition socket.hpp:1906
int startEncryption()
start socket encryption (perform TLS handshake).
Definition socket.hpp:1717
virtual bool encrypted() const noexcept override
check if the socket is secure.
Definition socket.hpp:2067
BasicTlsSocket(Mode mode, join::SslCtxPtr tlsContext)
Create socket instance specifying the socket mode and TLS context.
Definition socket.hpp:1620
join::SslCtxPtr _tlsContext
TLS context.
Definition socket.hpp:2537
virtual int read(char *data, unsigned long maxSize) noexcept override
read data on the socket.
Definition socket.hpp:1922
join::SslPtr _tlsHandle
TLS handle.
Definition socket.hpp:2540
virtual bool waitEncrypted(int timeout=0)
wait until TLS handshake is performed or timeout occurs (non blocking socket).
Definition socket.hpp:1769
int verifyCallback(int preverified, X509_STORE_CTX *context) const
trusted CA certificates verification callback.
Definition socket.hpp:2375
virtual bool waitReadyRead(int timeout=0) const noexcept override
block until new data is available for reading.
Definition socket.hpp:1890
int verifyCert(X509_STORE_CTX *context) const
verify certificate validity.
Definition socket.hpp:2439
void infoCallback(int where, int ret) const
state information callback.
Definition socket.hpp:2305
int setAlpnProtocols(const std::vector< std::string > &protocols)
set the ALPN protocols list.
Definition socket.hpp:2202
std::unique_ptr< BasicTlsSocket< Protocol > > Ptr
Definition socket.hpp:1554
BasicTlsSocket(Mode mode)
create instance specifying the mode.
Definition socket.hpp:1572
TlsState
TLS state.
Definition socket.hpp:2228
@ Encrypted
Definition socket.hpp:2229
@ NonEncrypted
Definition socket.hpp:2230
TlsState _tlsState
TLS state.
Definition socket.hpp:2543
bool checkHostName(X509 *certificate) const
confirm a match between the hostname contacted and the hostnames listed in the certificate.
Definition socket.hpp:2471
virtual int disconnect() override
shutdown the connection.
Definition socket.hpp:1811
int setCertificate(const std::string &cert, const std::string &key="")
set the certificate and the private key.
Definition socket.hpp:2078
void setVerify(bool verify, int depth=-1) noexcept
Enable/Disable the verification of the peer certificate.
Definition socket.hpp:2150
virtual bool waitReadyWrite(int timeout=0) const noexcept override
block until at least one byte can be written on the socket.
Definition socket.hpp:1986
int setCipher(const std::string &cipher)
set the cipher list (TLSv1.2 and below).
Definition socket.hpp:2168
typename BasicStreamSocket< Protocol >::State State
Definition socket.hpp:1557
typename Protocol::Endpoint Endpoint
Definition socket.hpp:1558
virtual void close() noexcept override
close the socket handle.
Definition socket.hpp:1878
int setCipher_1_3(const std::string &cipher)
set the cipher list (TLSv1.3).
Definition socket.hpp:2185
virtual ~BasicTlsSocket()=default
destroy the instance.
int setCaPath(const std::string &caPath)
set the location of the trusted CA certificates.
Definition socket.hpp:2114
BasicTlsSocket(const BasicTlsSocket &other)=delete
copy constructor.
int startHandshake()
Start SSL handshake.
Definition socket.hpp:2237
virtual int connectEncrypted(const Endpoint &endpoint)
make an encrypted connection to the given endpoint.
Definition socket.hpp:1694
typename BasicStreamSocket< Protocol >::Option Option
Definition socket.hpp:1556
static void infoWrapper(const SSL *ssl, int where, int ret)
c style callback wrapper for the state information callback.
Definition socket.hpp:2294
virtual int write(const char *data, unsigned long maxSize) noexcept override
write data on the socket.
Definition socket.hpp:2004
BasicTlsSocket(join::SslCtxPtr tlsContext)
create instance specifying TLS context.
Definition socket.hpp:1610
typename BasicStreamSocket< Protocol >::Mode Mode
Definition socket.hpp:1555
static int verifyWrapper(int preverified, X509_STORE_CTX *context)
c style callback wrapper for the Trusted CA certificates verification callback.
Definition socket.hpp:2360
BasicTlsSocket(BasicTlsSocket &&other)
move constructor.
Definition socket.hpp:1647
Event handler interface class.
Definition reactor.hpp:46
TLS error category.
Definition socket.hpp:1511
virtual std::string message(int code) const
translate TLS error code to human readable error string.
Definition socket.cpp:44
virtual const char * name() const noexcept
get TLS error category name.
Definition socket.cpp:35
const std::string key(65, 'a')
key.
Definition acceptor.hpp:32
bool operator<(const BasicUnixEndpoint< Protocol > &a, const BasicUnixEndpoint< Protocol > &b) noexcept
compare if endpoint is lower.
Definition endpoint.hpp:207
std::unique_ptr< SSL_CTX, SslCtxDelete > SslCtxPtr
Definition openssl.hpp:240
std::unique_ptr< STACK_OF(GENERAL_NAME), StackOfGeneralNameDelete > StackOfGeneralNamePtr
Definition openssl.hpp:210
const std::string defaultCipher_1_3
Definition openssl.cpp:40
TlsErrc
TLS error codes.
Definition socket.hpp:1502
const std::error_category & getTlsCategory()
get error category.
Definition socket.cpp:61
std::error_code make_error_code(join::Errc code) noexcept
Create an std::error_code object.
Definition error.cpp:150
const std::string defaultCipher
Definition openssl.cpp:36
std::unique_ptr< SSL, SslDelete > SslPtr
Definition openssl.hpp:225
std::error_condition make_error_condition(join::Errc code) noexcept
Create an std::error_condition object.
Definition error.cpp:159
Definition error.hpp:137