committed by
GitHub
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 2413 additions and 282 deletions
-
16CMakeLists.txt
-
16src/common/socket_types.h
-
18src/core/CMakeLists.txt
-
120src/core/hle/service/sockets/bsd.cpp
-
13src/core/hle/service/sockets/bsd.h
-
58src/core/hle/service/sockets/nsd.cpp
-
4src/core/hle/service/sockets/nsd.h
-
388src/core/hle/service/sockets/sfdnsres.cpp
-
3src/core/hle/service/sockets/sfdnsres.h
-
33src/core/hle/service/sockets/sockets.h
-
114src/core/hle/service/sockets/sockets_translate.cpp
-
17src/core/hle/service/sockets/sockets_translate.h
-
353src/core/hle/service/ssl/ssl.cpp
-
45src/core/hle/service/ssl/ssl_backend.h
-
16src/core/hle/service/ssl/ssl_backend_none.cpp
-
351src/core/hle/service/ssl/ssl_backend_openssl.cpp
-
543src/core/hle/service/ssl/ssl_backend_schannel.cpp
-
219src/core/hle/service/ssl/ssl_backend_securetransport.cpp
-
286src/core/internal_network/network.cpp
-
36src/core/internal_network/network.h
-
22src/core/internal_network/socket_proxy.cpp
-
8src/core/internal_network/socket_proxy.h
-
16src/core/internal_network/sockets.h
@ -0,0 +1,45 @@ |
|||||
|
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project |
||||
|
// SPDX-License-Identifier: GPL-2.0-or-later |
||||
|
|
||||
|
#pragma once |
||||
|
|
||||
|
#include "core/hle/result.h" |
||||
|
|
||||
|
#include "common/common_types.h" |
||||
|
|
||||
|
#include <memory> |
||||
|
#include <span> |
||||
|
#include <string> |
||||
|
#include <vector> |
||||
|
|
||||
|
namespace Network { |
||||
|
class SocketBase; |
||||
|
} |
||||
|
|
||||
|
namespace Service::SSL { |
||||
|
|
||||
|
constexpr Result ResultNoSocket{ErrorModule::SSLSrv, 103}; |
||||
|
constexpr Result ResultInvalidSocket{ErrorModule::SSLSrv, 106}; |
||||
|
constexpr Result ResultTimeout{ErrorModule::SSLSrv, 205}; |
||||
|
constexpr Result ResultInternalError{ErrorModule::SSLSrv, 999}; // made up |
||||
|
|
||||
|
// ResultWouldBlock is returned from Read and Write, and oddly, DoHandshake, |
||||
|
// with no way in the latter case to distinguish whether the client should poll |
||||
|
// for read or write. The one official client I've seen handles this by always |
||||
|
// polling for read (with a timeout). |
||||
|
constexpr Result ResultWouldBlock{ErrorModule::SSLSrv, 204}; |
||||
|
|
||||
|
class SSLConnectionBackend { |
||||
|
public: |
||||
|
virtual ~SSLConnectionBackend() {} |
||||
|
virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0; |
||||
|
virtual Result SetHostName(const std::string& hostname) = 0; |
||||
|
virtual Result DoHandshake() = 0; |
||||
|
virtual ResultVal<size_t> Read(std::span<u8> data) = 0; |
||||
|
virtual ResultVal<size_t> Write(std::span<const u8> data) = 0; |
||||
|
virtual ResultVal<std::vector<std::vector<u8>>> GetServerCerts() = 0; |
||||
|
}; |
||||
|
|
||||
|
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend(); |
||||
|
|
||||
|
} // namespace Service::SSL |
||||
@ -0,0 +1,16 @@ |
|||||
|
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
|
||||
|
// SPDX-License-Identifier: GPL-2.0-or-later
|
||||
|
|
||||
|
#include "core/hle/service/ssl/ssl_backend.h"
|
||||
|
|
||||
|
#include "common/logging/log.h"
|
||||
|
|
||||
|
namespace Service::SSL { |
||||
|
|
||||
|
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { |
||||
|
LOG_ERROR(Service_SSL, |
||||
|
"Can't create SSL connection because no SSL backend is available on this platform"); |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
|
||||
|
} // namespace Service::SSL
|
||||
@ -0,0 +1,351 @@ |
|||||
|
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
|
||||
|
// SPDX-License-Identifier: GPL-2.0-or-later
|
||||
|
|
||||
|
#include "core/hle/service/ssl/ssl_backend.h"
|
||||
|
#include "core/internal_network/network.h"
|
||||
|
#include "core/internal_network/sockets.h"
|
||||
|
|
||||
|
#include "common/fs/file.h"
|
||||
|
#include "common/hex_util.h"
|
||||
|
#include "common/string_util.h"
|
||||
|
|
||||
|
#include <mutex>
|
||||
|
|
||||
|
#include <openssl/bio.h>
|
||||
|
#include <openssl/err.h>
|
||||
|
#include <openssl/ssl.h>
|
||||
|
#include <openssl/x509.h>
|
||||
|
|
||||
|
using namespace Common::FS; |
||||
|
|
||||
|
namespace Service::SSL { |
||||
|
|
||||
|
// Import OpenSSL's `SSL` type into the namespace. This is needed because the
|
||||
|
// namespace is also named `SSL`.
|
||||
|
using ::SSL; |
||||
|
|
||||
|
namespace { |
||||
|
|
||||
|
std::once_flag one_time_init_flag; |
||||
|
bool one_time_init_success = false; |
||||
|
|
||||
|
SSL_CTX* ssl_ctx; |
||||
|
IOFile key_log_file; // only open if SSLKEYLOGFILE set in environment
|
||||
|
BIO_METHOD* bio_meth; |
||||
|
|
||||
|
Result CheckOpenSSLErrors(); |
||||
|
void OneTimeInit(); |
||||
|
void OneTimeInitLogFile(); |
||||
|
bool OneTimeInitBIO(); |
||||
|
|
||||
|
} // namespace
|
||||
|
|
||||
|
class SSLConnectionBackendOpenSSL final : public SSLConnectionBackend { |
||||
|
public: |
||||
|
Result Init() { |
||||
|
std::call_once(one_time_init_flag, OneTimeInit); |
||||
|
|
||||
|
if (!one_time_init_success) { |
||||
|
LOG_ERROR(Service_SSL, |
||||
|
"Can't create SSL connection because OpenSSL one-time initialization failed"); |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
|
||||
|
ssl = SSL_new(ssl_ctx); |
||||
|
if (!ssl) { |
||||
|
LOG_ERROR(Service_SSL, "SSL_new failed"); |
||||
|
return CheckOpenSSLErrors(); |
||||
|
} |
||||
|
|
||||
|
SSL_set_connect_state(ssl); |
||||
|
|
||||
|
bio = BIO_new(bio_meth); |
||||
|
if (!bio) { |
||||
|
LOG_ERROR(Service_SSL, "BIO_new failed"); |
||||
|
return CheckOpenSSLErrors(); |
||||
|
} |
||||
|
|
||||
|
BIO_set_data(bio, this); |
||||
|
BIO_set_init(bio, 1); |
||||
|
SSL_set_bio(ssl, bio, bio); |
||||
|
|
||||
|
return ResultSuccess; |
||||
|
} |
||||
|
|
||||
|
void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override { |
||||
|
socket = std::move(socket_in); |
||||
|
} |
||||
|
|
||||
|
Result SetHostName(const std::string& hostname) override { |
||||
|
if (!SSL_set1_host(ssl, hostname.c_str())) { // hostname for verification
|
||||
|
LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname); |
||||
|
return CheckOpenSSLErrors(); |
||||
|
} |
||||
|
if (!SSL_set_tlsext_host_name(ssl, hostname.c_str())) { // hostname for SNI
|
||||
|
LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname); |
||||
|
return CheckOpenSSLErrors(); |
||||
|
} |
||||
|
return ResultSuccess; |
||||
|
} |
||||
|
|
||||
|
Result DoHandshake() override { |
||||
|
SSL_set_verify_result(ssl, X509_V_OK); |
||||
|
const int ret = SSL_do_handshake(ssl); |
||||
|
const long verify_result = SSL_get_verify_result(ssl); |
||||
|
if (verify_result != X509_V_OK) { |
||||
|
LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}", |
||||
|
X509_verify_cert_error_string(verify_result)); |
||||
|
return CheckOpenSSLErrors(); |
||||
|
} |
||||
|
if (ret <= 0) { |
||||
|
const int ssl_err = SSL_get_error(ssl, ret); |
||||
|
if (ssl_err == SSL_ERROR_ZERO_RETURN || |
||||
|
(ssl_err == SSL_ERROR_SYSCALL && got_read_eof)) { |
||||
|
LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
} |
||||
|
return HandleReturn("SSL_do_handshake", 0, ret).Code(); |
||||
|
} |
||||
|
|
||||
|
ResultVal<size_t> Read(std::span<u8> data) override { |
||||
|
size_t actual; |
||||
|
const int ret = SSL_read_ex(ssl, data.data(), data.size(), &actual); |
||||
|
return HandleReturn("SSL_read_ex", actual, ret); |
||||
|
} |
||||
|
|
||||
|
ResultVal<size_t> Write(std::span<const u8> data) override { |
||||
|
size_t actual; |
||||
|
const int ret = SSL_write_ex(ssl, data.data(), data.size(), &actual); |
||||
|
return HandleReturn("SSL_write_ex", actual, ret); |
||||
|
} |
||||
|
|
||||
|
ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) { |
||||
|
const int ssl_err = SSL_get_error(ssl, ret); |
||||
|
CheckOpenSSLErrors(); |
||||
|
switch (ssl_err) { |
||||
|
case SSL_ERROR_NONE: |
||||
|
return actual; |
||||
|
case SSL_ERROR_ZERO_RETURN: |
||||
|
LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what); |
||||
|
// DoHandshake special-cases this, but for Read and Write:
|
||||
|
return size_t(0); |
||||
|
case SSL_ERROR_WANT_READ: |
||||
|
LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what); |
||||
|
return ResultWouldBlock; |
||||
|
case SSL_ERROR_WANT_WRITE: |
||||
|
LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what); |
||||
|
return ResultWouldBlock; |
||||
|
default: |
||||
|
if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof) { |
||||
|
LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what); |
||||
|
return size_t(0); |
||||
|
} |
||||
|
LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err); |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { |
||||
|
STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl); |
||||
|
if (!chain) { |
||||
|
LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr"); |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
std::vector<std::vector<u8>> ret; |
||||
|
int count = sk_X509_num(chain); |
||||
|
ASSERT(count >= 0); |
||||
|
for (int i = 0; i < count; i++) { |
||||
|
X509* x509 = sk_X509_value(chain, i); |
||||
|
ASSERT_OR_EXECUTE(x509 != nullptr, { continue; }); |
||||
|
unsigned char* buf = nullptr; |
||||
|
int len = i2d_X509(x509, &buf); |
||||
|
ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; }); |
||||
|
ret.emplace_back(buf, buf + len); |
||||
|
OPENSSL_free(buf); |
||||
|
} |
||||
|
return ret; |
||||
|
} |
||||
|
|
||||
|
~SSLConnectionBackendOpenSSL() { |
||||
|
// these are null-tolerant:
|
||||
|
SSL_free(ssl); |
||||
|
BIO_free(bio); |
||||
|
} |
||||
|
|
||||
|
static void KeyLogCallback(const SSL* ssl, const char* line) { |
||||
|
std::string str(line); |
||||
|
str.push_back('\n'); |
||||
|
// Do this in a single WriteString for atomicity if multiple instances
|
||||
|
// are running on different threads (though that can't currently
|
||||
|
// happen).
|
||||
|
if (key_log_file.WriteString(str) != str.size() || !key_log_file.Flush()) { |
||||
|
LOG_CRITICAL(Service_SSL, "Failed to write to SSLKEYLOGFILE"); |
||||
|
} |
||||
|
LOG_DEBUG(Service_SSL, "Wrote to SSLKEYLOGFILE: {}", line); |
||||
|
} |
||||
|
|
||||
|
static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) { |
||||
|
auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio)); |
||||
|
ASSERT_OR_EXECUTE_MSG( |
||||
|
self->socket, { return 0; }, "OpenSSL asked to send but we have no socket"); |
||||
|
BIO_clear_retry_flags(bio); |
||||
|
auto [actual, err] = self->socket->Send({reinterpret_cast<const u8*>(buf), len}, 0); |
||||
|
switch (err) { |
||||
|
case Network::Errno::SUCCESS: |
||||
|
*actual_p = actual; |
||||
|
return 1; |
||||
|
case Network::Errno::AGAIN: |
||||
|
BIO_set_flags(bio, BIO_FLAGS_WRITE | BIO_FLAGS_SHOULD_RETRY); |
||||
|
return 0; |
||||
|
default: |
||||
|
LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err); |
||||
|
return -1; |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) { |
||||
|
auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio)); |
||||
|
ASSERT_OR_EXECUTE_MSG( |
||||
|
self->socket, { return 0; }, "OpenSSL asked to recv but we have no socket"); |
||||
|
BIO_clear_retry_flags(bio); |
||||
|
auto [actual, err] = self->socket->Recv(0, {reinterpret_cast<u8*>(buf), len}); |
||||
|
switch (err) { |
||||
|
case Network::Errno::SUCCESS: |
||||
|
*actual_p = actual; |
||||
|
if (actual == 0) { |
||||
|
self->got_read_eof = true; |
||||
|
} |
||||
|
return actual ? 1 : 0; |
||||
|
case Network::Errno::AGAIN: |
||||
|
BIO_set_flags(bio, BIO_FLAGS_READ | BIO_FLAGS_SHOULD_RETRY); |
||||
|
return 0; |
||||
|
default: |
||||
|
LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err); |
||||
|
return -1; |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
static long CtrlCallback(BIO* bio, int cmd, long l_arg, void* p_arg) { |
||||
|
switch (cmd) { |
||||
|
case BIO_CTRL_FLUSH: |
||||
|
// Nothing to flush.
|
||||
|
return 1; |
||||
|
case BIO_CTRL_PUSH: |
||||
|
case BIO_CTRL_POP: |
||||
|
#ifdef BIO_CTRL_GET_KTLS_SEND
|
||||
|
case BIO_CTRL_GET_KTLS_SEND: |
||||
|
case BIO_CTRL_GET_KTLS_RECV: |
||||
|
#endif
|
||||
|
// We don't support these operations, but don't bother logging them
|
||||
|
// as they're nothing unusual.
|
||||
|
return 0; |
||||
|
default: |
||||
|
LOG_DEBUG(Service_SSL, "OpenSSL BIO got ctrl({}, {}, {})", cmd, l_arg, p_arg); |
||||
|
return 0; |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
SSL* ssl = nullptr; |
||||
|
BIO* bio = nullptr; |
||||
|
bool got_read_eof = false; |
||||
|
|
||||
|
std::shared_ptr<Network::SocketBase> socket; |
||||
|
}; |
||||
|
|
||||
|
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { |
||||
|
auto conn = std::make_unique<SSLConnectionBackendOpenSSL>(); |
||||
|
const Result res = conn->Init(); |
||||
|
if (res.IsFailure()) { |
||||
|
return res; |
||||
|
} |
||||
|
return conn; |
||||
|
} |
||||
|
|
||||
|
namespace { |
||||
|
|
||||
|
Result CheckOpenSSLErrors() { |
||||
|
unsigned long rc; |
||||
|
const char* file; |
||||
|
int line; |
||||
|
const char* func; |
||||
|
const char* data; |
||||
|
int flags; |
||||
|
#if OPENSSL_VERSION_NUMBER >= 0x30000000L
|
||||
|
while ((rc = ERR_get_error_all(&file, &line, &func, &data, &flags))) |
||||
|
#else
|
||||
|
// Can't get function names from OpenSSL on this version, so use mine:
|
||||
|
func = __func__; |
||||
|
while ((rc = ERR_get_error_line_data(&file, &line, &data, &flags))) |
||||
|
#endif
|
||||
|
{ |
||||
|
std::string msg; |
||||
|
msg.resize(1024, '\0'); |
||||
|
ERR_error_string_n(rc, msg.data(), msg.size()); |
||||
|
msg.resize(strlen(msg.data()), '\0'); |
||||
|
if (flags & ERR_TXT_STRING) { |
||||
|
msg.append(" | "); |
||||
|
msg.append(data); |
||||
|
} |
||||
|
Common::Log::FmtLogMessage(Common::Log::Class::Service_SSL, Common::Log::Level::Error, |
||||
|
Common::Log::TrimSourcePath(file), line, func, "OpenSSL: {}", |
||||
|
msg); |
||||
|
} |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
|
||||
|
void OneTimeInit() { |
||||
|
ssl_ctx = SSL_CTX_new(TLS_client_method()); |
||||
|
if (!ssl_ctx) { |
||||
|
LOG_ERROR(Service_SSL, "SSL_CTX_new failed"); |
||||
|
CheckOpenSSLErrors(); |
||||
|
return; |
||||
|
} |
||||
|
|
||||
|
SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr); |
||||
|
|
||||
|
if (!SSL_CTX_set_default_verify_paths(ssl_ctx)) { |
||||
|
LOG_ERROR(Service_SSL, "SSL_CTX_set_default_verify_paths failed"); |
||||
|
CheckOpenSSLErrors(); |
||||
|
return; |
||||
|
} |
||||
|
|
||||
|
OneTimeInitLogFile(); |
||||
|
|
||||
|
if (!OneTimeInitBIO()) { |
||||
|
return; |
||||
|
} |
||||
|
|
||||
|
one_time_init_success = true; |
||||
|
} |
||||
|
|
||||
|
void OneTimeInitLogFile() { |
||||
|
const char* logfile = getenv("SSLKEYLOGFILE"); |
||||
|
if (logfile) { |
||||
|
key_log_file.Open(logfile, FileAccessMode::Append, FileType::TextFile, |
||||
|
FileShareFlag::ShareWriteOnly); |
||||
|
if (key_log_file.IsOpen()) { |
||||
|
SSL_CTX_set_keylog_callback(ssl_ctx, &SSLConnectionBackendOpenSSL::KeyLogCallback); |
||||
|
} else { |
||||
|
LOG_CRITICAL(Service_SSL, |
||||
|
"SSLKEYLOGFILE was set but file could not be opened; not logging keys!"); |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
bool OneTimeInitBIO() { |
||||
|
bio_meth = |
||||
|
BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "SSLConnectionBackendOpenSSL"); |
||||
|
if (!bio_meth || |
||||
|
!BIO_meth_set_write_ex(bio_meth, &SSLConnectionBackendOpenSSL::WriteCallback) || |
||||
|
!BIO_meth_set_read_ex(bio_meth, &SSLConnectionBackendOpenSSL::ReadCallback) || |
||||
|
!BIO_meth_set_ctrl(bio_meth, &SSLConnectionBackendOpenSSL::CtrlCallback)) { |
||||
|
LOG_ERROR(Service_SSL, "Failed to create BIO_METHOD"); |
||||
|
return false; |
||||
|
} |
||||
|
return true; |
||||
|
} |
||||
|
|
||||
|
} // namespace
|
||||
|
|
||||
|
} // namespace Service::SSL
|
||||
@ -0,0 +1,543 @@ |
|||||
|
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
|
||||
|
// SPDX-License-Identifier: GPL-2.0-or-later
|
||||
|
|
||||
|
#include "core/hle/service/ssl/ssl_backend.h"
|
||||
|
#include "core/internal_network/network.h"
|
||||
|
#include "core/internal_network/sockets.h"
|
||||
|
|
||||
|
#include "common/error.h"
|
||||
|
#include "common/fs/file.h"
|
||||
|
#include "common/hex_util.h"
|
||||
|
#include "common/string_util.h"
|
||||
|
|
||||
|
#include <mutex>
|
||||
|
|
||||
|
namespace { |
||||
|
|
||||
|
// These includes are inside the namespace to avoid a conflict on MinGW where
|
||||
|
// the headers define an enum containing Network and Service as enumerators
|
||||
|
// (which clash with the correspondingly named namespaces).
|
||||
|
#define SECURITY_WIN32
|
||||
|
#include <schnlsp.h>
|
||||
|
#include <security.h>
|
||||
|
|
||||
|
std::once_flag one_time_init_flag; |
||||
|
bool one_time_init_success = false; |
||||
|
|
||||
|
SCHANNEL_CRED schannel_cred{}; |
||||
|
CredHandle cred_handle; |
||||
|
|
||||
|
static void OneTimeInit() { |
||||
|
schannel_cred.dwVersion = SCHANNEL_CRED_VERSION; |
||||
|
schannel_cred.dwFlags = |
||||
|
SCH_USE_STRONG_CRYPTO | // don't allow insecure protocols
|
||||
|
SCH_CRED_AUTO_CRED_VALIDATION | // validate certs
|
||||
|
SCH_CRED_NO_DEFAULT_CREDS; // don't automatically present a client certificate
|
||||
|
// ^ I'm assuming that nobody would want to connect Yuzu to a
|
||||
|
// service that requires some OS-provided corporate client
|
||||
|
// certificate, and presenting one to some arbitrary server
|
||||
|
// might be a privacy concern? Who knows, though.
|
||||
|
|
||||
|
const SECURITY_STATUS ret = |
||||
|
AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND, |
||||
|
nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr); |
||||
|
if (ret != SEC_E_OK) { |
||||
|
// SECURITY_STATUS codes are a type of HRESULT and can be used with NativeErrorToString.
|
||||
|
LOG_ERROR(Service_SSL, "AcquireCredentialsHandle failed: {}", |
||||
|
Common::NativeErrorToString(ret)); |
||||
|
return; |
||||
|
} |
||||
|
|
||||
|
if (getenv("SSLKEYLOGFILE")) { |
||||
|
LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but Schannel does not support exporting " |
||||
|
"keys; not logging keys!"); |
||||
|
// Not fatal.
|
||||
|
} |
||||
|
|
||||
|
one_time_init_success = true; |
||||
|
} |
||||
|
|
||||
|
} // namespace
|
||||
|
|
||||
|
namespace Service::SSL { |
||||
|
|
||||
|
class SSLConnectionBackendSchannel final : public SSLConnectionBackend { |
||||
|
public: |
||||
|
Result Init() { |
||||
|
std::call_once(one_time_init_flag, OneTimeInit); |
||||
|
|
||||
|
if (!one_time_init_success) { |
||||
|
LOG_ERROR( |
||||
|
Service_SSL, |
||||
|
"Can't create SSL connection because Schannel one-time initialization failed"); |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
|
||||
|
return ResultSuccess; |
||||
|
} |
||||
|
|
||||
|
void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override { |
||||
|
socket = std::move(socket_in); |
||||
|
} |
||||
|
|
||||
|
Result SetHostName(const std::string& hostname_in) override { |
||||
|
hostname = hostname_in; |
||||
|
return ResultSuccess; |
||||
|
} |
||||
|
|
||||
|
Result DoHandshake() override { |
||||
|
while (1) { |
||||
|
Result r; |
||||
|
switch (handshake_state) { |
||||
|
case HandshakeState::Initial: |
||||
|
if ((r = FlushCiphertextWriteBuf()) != ResultSuccess || |
||||
|
(r = CallInitializeSecurityContext()) != ResultSuccess) { |
||||
|
return r; |
||||
|
} |
||||
|
// CallInitializeSecurityContext updated `handshake_state`.
|
||||
|
continue; |
||||
|
case HandshakeState::ContinueNeeded: |
||||
|
case HandshakeState::IncompleteMessage: |
||||
|
if ((r = FlushCiphertextWriteBuf()) != ResultSuccess || |
||||
|
(r = FillCiphertextReadBuf()) != ResultSuccess) { |
||||
|
return r; |
||||
|
} |
||||
|
if (ciphertext_read_buf.empty()) { |
||||
|
LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
if ((r = CallInitializeSecurityContext()) != ResultSuccess) { |
||||
|
return r; |
||||
|
} |
||||
|
// CallInitializeSecurityContext updated `handshake_state`.
|
||||
|
continue; |
||||
|
case HandshakeState::DoneAfterFlush: |
||||
|
if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) { |
||||
|
return r; |
||||
|
} |
||||
|
handshake_state = HandshakeState::Connected; |
||||
|
return ResultSuccess; |
||||
|
case HandshakeState::Connected: |
||||
|
LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook"); |
||||
|
return ResultInternalError; |
||||
|
case HandshakeState::Error: |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
Result FillCiphertextReadBuf() { |
||||
|
const size_t fill_size = read_buf_fill_size ? read_buf_fill_size : 4096; |
||||
|
read_buf_fill_size = 0; |
||||
|
// This unnecessarily zeroes the buffer; oh well.
|
||||
|
const size_t offset = ciphertext_read_buf.size(); |
||||
|
ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; }); |
||||
|
ciphertext_read_buf.resize(offset + fill_size, 0); |
||||
|
const auto read_span = std::span(ciphertext_read_buf).subspan(offset, fill_size); |
||||
|
const auto [actual, err] = socket->Recv(0, read_span); |
||||
|
switch (err) { |
||||
|
case Network::Errno::SUCCESS: |
||||
|
ASSERT(static_cast<size_t>(actual) <= fill_size); |
||||
|
ciphertext_read_buf.resize(offset + actual); |
||||
|
return ResultSuccess; |
||||
|
case Network::Errno::AGAIN: |
||||
|
ciphertext_read_buf.resize(offset); |
||||
|
return ResultWouldBlock; |
||||
|
default: |
||||
|
ciphertext_read_buf.resize(offset); |
||||
|
LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err); |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Returns success if the write buffer has been completely emptied.
|
||||
|
Result FlushCiphertextWriteBuf() { |
||||
|
while (!ciphertext_write_buf.empty()) { |
||||
|
const auto [actual, err] = socket->Send(ciphertext_write_buf, 0); |
||||
|
switch (err) { |
||||
|
case Network::Errno::SUCCESS: |
||||
|
ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf.size()); |
||||
|
ciphertext_write_buf.erase(ciphertext_write_buf.begin(), |
||||
|
ciphertext_write_buf.begin() + actual); |
||||
|
break; |
||||
|
case Network::Errno::AGAIN: |
||||
|
return ResultWouldBlock; |
||||
|
default: |
||||
|
LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err); |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
} |
||||
|
return ResultSuccess; |
||||
|
} |
||||
|
|
||||
|
Result CallInitializeSecurityContext() { |
||||
|
const unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | |
||||
|
ISC_REQ_INTEGRITY | ISC_REQ_REPLAY_DETECT | |
||||
|
ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM | |
||||
|
ISC_REQ_USE_SUPPLIED_CREDS; |
||||
|
unsigned long attr; |
||||
|
// https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel
|
||||
|
std::array<SecBuffer, 2> input_buffers{{ |
||||
|
// only used if `initial_call_done`
|
||||
|
{ |
||||
|
// [0]
|
||||
|
.cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()), |
||||
|
.BufferType = SECBUFFER_TOKEN, |
||||
|
.pvBuffer = ciphertext_read_buf.data(), |
||||
|
}, |
||||
|
{ |
||||
|
// [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is
|
||||
|
// returned, or SECBUFFER_EXTRA when SEC_E_CONTINUE_NEEDED is returned if the
|
||||
|
// whole buffer wasn't used)
|
||||
|
.cbBuffer = 0, |
||||
|
.BufferType = SECBUFFER_EMPTY, |
||||
|
.pvBuffer = nullptr, |
||||
|
}, |
||||
|
}}; |
||||
|
std::array<SecBuffer, 2> output_buffers{{ |
||||
|
{ |
||||
|
.cbBuffer = 0, |
||||
|
.BufferType = SECBUFFER_TOKEN, |
||||
|
.pvBuffer = nullptr, |
||||
|
}, // [0]
|
||||
|
{ |
||||
|
.cbBuffer = 0, |
||||
|
.BufferType = SECBUFFER_ALERT, |
||||
|
.pvBuffer = nullptr, |
||||
|
}, // [1]
|
||||
|
}}; |
||||
|
SecBufferDesc input_desc{ |
||||
|
.ulVersion = SECBUFFER_VERSION, |
||||
|
.cBuffers = static_cast<unsigned long>(input_buffers.size()), |
||||
|
.pBuffers = input_buffers.data(), |
||||
|
}; |
||||
|
SecBufferDesc output_desc{ |
||||
|
.ulVersion = SECBUFFER_VERSION, |
||||
|
.cBuffers = static_cast<unsigned long>(output_buffers.size()), |
||||
|
.pBuffers = output_buffers.data(), |
||||
|
}; |
||||
|
ASSERT_OR_EXECUTE_MSG( |
||||
|
input_buffers[0].cbBuffer == ciphertext_read_buf.size(), |
||||
|
{ return ResultInternalError; }, "read buffer too large"); |
||||
|
|
||||
|
bool initial_call_done = handshake_state != HandshakeState::Initial; |
||||
|
if (initial_call_done) { |
||||
|
LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext", |
||||
|
ciphertext_read_buf.size()); |
||||
|
} |
||||
|
|
||||
|
const SECURITY_STATUS ret = |
||||
|
InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt : nullptr, |
||||
|
// Caller ensured we have set a hostname:
|
||||
|
const_cast<char*>(hostname.value().c_str()), req, |
||||
|
0, // Reserved1
|
||||
|
0, // TargetDataRep not used with Schannel
|
||||
|
initial_call_done ? &input_desc : nullptr, |
||||
|
0, // Reserved2
|
||||
|
initial_call_done ? nullptr : &ctxt, &output_desc, &attr, |
||||
|
nullptr); // ptsExpiry
|
||||
|
|
||||
|
if (output_buffers[0].pvBuffer) { |
||||
|
const std::span span(static_cast<u8*>(output_buffers[0].pvBuffer), |
||||
|
output_buffers[0].cbBuffer); |
||||
|
ciphertext_write_buf.insert(ciphertext_write_buf.end(), span.begin(), span.end()); |
||||
|
FreeContextBuffer(output_buffers[0].pvBuffer); |
||||
|
} |
||||
|
|
||||
|
if (output_buffers[1].pvBuffer) { |
||||
|
const std::span span(static_cast<u8*>(output_buffers[1].pvBuffer), |
||||
|
output_buffers[1].cbBuffer); |
||||
|
// The documentation doesn't explain what format this data is in.
|
||||
|
LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(), |
||||
|
Common::HexToString(span)); |
||||
|
} |
||||
|
|
||||
|
switch (ret) { |
||||
|
case SEC_I_CONTINUE_NEEDED: |
||||
|
LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED"); |
||||
|
if (input_buffers[1].BufferType == SECBUFFER_EXTRA) { |
||||
|
LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer); |
||||
|
ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf.size()); |
||||
|
ciphertext_read_buf.erase(ciphertext_read_buf.begin(), |
||||
|
ciphertext_read_buf.end() - input_buffers[1].cbBuffer); |
||||
|
} else { |
||||
|
ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY); |
||||
|
ciphertext_read_buf.clear(); |
||||
|
} |
||||
|
handshake_state = HandshakeState::ContinueNeeded; |
||||
|
return ResultSuccess; |
||||
|
case SEC_E_INCOMPLETE_MESSAGE: |
||||
|
LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE"); |
||||
|
ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING); |
||||
|
read_buf_fill_size = input_buffers[1].cbBuffer; |
||||
|
handshake_state = HandshakeState::IncompleteMessage; |
||||
|
return ResultSuccess; |
||||
|
case SEC_E_OK: |
||||
|
LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK"); |
||||
|
ciphertext_read_buf.clear(); |
||||
|
handshake_state = HandshakeState::DoneAfterFlush; |
||||
|
return GrabStreamSizes(); |
||||
|
default: |
||||
|
LOG_ERROR(Service_SSL, |
||||
|
"InitializeSecurityContext failed (probably certificate/protocol issue): {}", |
||||
|
Common::NativeErrorToString(ret)); |
||||
|
handshake_state = HandshakeState::Error; |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
Result GrabStreamSizes() { |
||||
|
const SECURITY_STATUS ret = |
||||
|
QueryContextAttributes(&ctxt, SECPKG_ATTR_STREAM_SIZES, &stream_sizes); |
||||
|
if (ret != SEC_E_OK) { |
||||
|
LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}", |
||||
|
Common::NativeErrorToString(ret)); |
||||
|
handshake_state = HandshakeState::Error; |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
return ResultSuccess; |
||||
|
} |
||||
|
|
||||
|
ResultVal<size_t> Read(std::span<u8> data) override { |
||||
|
if (handshake_state != HandshakeState::Connected) { |
||||
|
LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake"); |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
if (data.size() == 0 || got_read_eof) { |
||||
|
return size_t(0); |
||||
|
} |
||||
|
while (1) { |
||||
|
if (!cleartext_read_buf.empty()) { |
||||
|
const size_t read_size = std::min(cleartext_read_buf.size(), data.size()); |
||||
|
std::memcpy(data.data(), cleartext_read_buf.data(), read_size); |
||||
|
cleartext_read_buf.erase(cleartext_read_buf.begin(), |
||||
|
cleartext_read_buf.begin() + read_size); |
||||
|
return read_size; |
||||
|
} |
||||
|
if (!ciphertext_read_buf.empty()) { |
||||
|
SecBuffer empty{ |
||||
|
.cbBuffer = 0, |
||||
|
.BufferType = SECBUFFER_EMPTY, |
||||
|
.pvBuffer = nullptr, |
||||
|
}; |
||||
|
std::array<SecBuffer, 5> buffers{{ |
||||
|
{ |
||||
|
.cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()), |
||||
|
.BufferType = SECBUFFER_DATA, |
||||
|
.pvBuffer = ciphertext_read_buf.data(), |
||||
|
}, |
||||
|
empty, |
||||
|
empty, |
||||
|
empty, |
||||
|
}}; |
||||
|
ASSERT_OR_EXECUTE_MSG( |
||||
|
buffers[0].cbBuffer == ciphertext_read_buf.size(), |
||||
|
{ return ResultInternalError; }, "read buffer too large"); |
||||
|
SecBufferDesc desc{ |
||||
|
.ulVersion = SECBUFFER_VERSION, |
||||
|
.cBuffers = static_cast<unsigned long>(buffers.size()), |
||||
|
.pBuffers = buffers.data(), |
||||
|
}; |
||||
|
SECURITY_STATUS ret = |
||||
|
DecryptMessage(&ctxt, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr); |
||||
|
switch (ret) { |
||||
|
case SEC_E_OK: |
||||
|
ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER, |
||||
|
{ return ResultInternalError; }); |
||||
|
ASSERT_OR_EXECUTE(buffers[1].BufferType == SECBUFFER_DATA, |
||||
|
{ return ResultInternalError; }); |
||||
|
ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER, |
||||
|
{ return ResultInternalError; }); |
||||
|
cleartext_read_buf.assign(static_cast<u8*>(buffers[1].pvBuffer), |
||||
|
static_cast<u8*>(buffers[1].pvBuffer) + |
||||
|
buffers[1].cbBuffer); |
||||
|
if (buffers[3].BufferType == SECBUFFER_EXTRA) { |
||||
|
ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf.size()); |
||||
|
ciphertext_read_buf.erase(ciphertext_read_buf.begin(), |
||||
|
ciphertext_read_buf.end() - buffers[3].cbBuffer); |
||||
|
} else { |
||||
|
ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY); |
||||
|
ciphertext_read_buf.clear(); |
||||
|
} |
||||
|
continue; |
||||
|
case SEC_E_INCOMPLETE_MESSAGE: |
||||
|
break; |
||||
|
case SEC_I_CONTEXT_EXPIRED: |
||||
|
// Server hung up by sending close_notify.
|
||||
|
got_read_eof = true; |
||||
|
return size_t(0); |
||||
|
default: |
||||
|
LOG_ERROR(Service_SSL, "DecryptMessage failed: {}", |
||||
|
Common::NativeErrorToString(ret)); |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
} |
||||
|
const Result r = FillCiphertextReadBuf(); |
||||
|
if (r != ResultSuccess) { |
||||
|
return r; |
||||
|
} |
||||
|
if (ciphertext_read_buf.empty()) { |
||||
|
got_read_eof = true; |
||||
|
return size_t(0); |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
ResultVal<size_t> Write(std::span<const u8> data) override { |
||||
|
if (handshake_state != HandshakeState::Connected) { |
||||
|
LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake"); |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
if (data.size() == 0) { |
||||
|
return size_t(0); |
||||
|
} |
||||
|
data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes.cbMaximumMessage)); |
||||
|
if (!cleartext_write_buf.empty()) { |
||||
|
// Already in the middle of a write. It wouldn't make sense to not
|
||||
|
// finish sending the entire buffer since TLS has
|
||||
|
// header/MAC/padding/etc.
|
||||
|
if (data.size() != cleartext_write_buf.size() || |
||||
|
std::memcmp(data.data(), cleartext_write_buf.data(), data.size())) { |
||||
|
LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer"); |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
return WriteAlreadyEncryptedData(); |
||||
|
} else { |
||||
|
cleartext_write_buf.assign(data.begin(), data.end()); |
||||
|
} |
||||
|
|
||||
|
std::vector<u8> header_buf(stream_sizes.cbHeader, 0); |
||||
|
std::vector<u8> tmp_data_buf = cleartext_write_buf; |
||||
|
std::vector<u8> trailer_buf(stream_sizes.cbTrailer, 0); |
||||
|
|
||||
|
std::array<SecBuffer, 3> buffers{{ |
||||
|
{ |
||||
|
.cbBuffer = stream_sizes.cbHeader, |
||||
|
.BufferType = SECBUFFER_STREAM_HEADER, |
||||
|
.pvBuffer = header_buf.data(), |
||||
|
}, |
||||
|
{ |
||||
|
.cbBuffer = static_cast<unsigned long>(tmp_data_buf.size()), |
||||
|
.BufferType = SECBUFFER_DATA, |
||||
|
.pvBuffer = tmp_data_buf.data(), |
||||
|
}, |
||||
|
{ |
||||
|
.cbBuffer = stream_sizes.cbTrailer, |
||||
|
.BufferType = SECBUFFER_STREAM_TRAILER, |
||||
|
.pvBuffer = trailer_buf.data(), |
||||
|
}, |
||||
|
}}; |
||||
|
ASSERT_OR_EXECUTE_MSG( |
||||
|
buffers[1].cbBuffer == tmp_data_buf.size(), { return ResultInternalError; }, |
||||
|
"temp buffer too large"); |
||||
|
SecBufferDesc desc{ |
||||
|
.ulVersion = SECBUFFER_VERSION, |
||||
|
.cBuffers = static_cast<unsigned long>(buffers.size()), |
||||
|
.pBuffers = buffers.data(), |
||||
|
}; |
||||
|
|
||||
|
const SECURITY_STATUS ret = EncryptMessage(&ctxt, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0); |
||||
|
if (ret != SEC_E_OK) { |
||||
|
LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret)); |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
ciphertext_write_buf.insert(ciphertext_write_buf.end(), header_buf.begin(), |
||||
|
header_buf.end()); |
||||
|
ciphertext_write_buf.insert(ciphertext_write_buf.end(), tmp_data_buf.begin(), |
||||
|
tmp_data_buf.end()); |
||||
|
ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(), |
||||
|
trailer_buf.end()); |
||||
|
return WriteAlreadyEncryptedData(); |
||||
|
} |
||||
|
|
||||
|
ResultVal<size_t> WriteAlreadyEncryptedData() { |
||||
|
const Result r = FlushCiphertextWriteBuf(); |
||||
|
if (r != ResultSuccess) { |
||||
|
return r; |
||||
|
} |
||||
|
// write buf is empty
|
||||
|
const size_t cleartext_bytes_written = cleartext_write_buf.size(); |
||||
|
cleartext_write_buf.clear(); |
||||
|
return cleartext_bytes_written; |
||||
|
} |
||||
|
|
||||
|
ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { |
||||
|
PCCERT_CONTEXT returned_cert = nullptr; |
||||
|
const SECURITY_STATUS ret = |
||||
|
QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert); |
||||
|
if (ret != SEC_E_OK) { |
||||
|
LOG_ERROR(Service_SSL, |
||||
|
"QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}", |
||||
|
Common::NativeErrorToString(ret)); |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
PCCERT_CONTEXT some_cert = nullptr; |
||||
|
std::vector<std::vector<u8>> certs; |
||||
|
while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) { |
||||
|
certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded), |
||||
|
static_cast<u8*>(some_cert->pbCertEncoded) + |
||||
|
some_cert->cbCertEncoded); |
||||
|
} |
||||
|
std::reverse(certs.begin(), |
||||
|
certs.end()); // Windows returns certs in reverse order from what we want
|
||||
|
CertFreeCertificateContext(returned_cert); |
||||
|
return certs; |
||||
|
} |
||||
|
|
||||
|
~SSLConnectionBackendSchannel() { |
||||
|
if (handshake_state != HandshakeState::Initial) { |
||||
|
DeleteSecurityContext(&ctxt); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
enum class HandshakeState { |
||||
|
// Haven't called anything yet.
|
||||
|
Initial, |
||||
|
// `SEC_I_CONTINUE_NEEDED` was returned by
|
||||
|
// `InitializeSecurityContext`; must finish sending data (if any) in
|
||||
|
// the write buffer, then read at least one byte before calling
|
||||
|
// `InitializeSecurityContext` again.
|
||||
|
ContinueNeeded, |
||||
|
// `SEC_E_INCOMPLETE_MESSAGE` was returned by
|
||||
|
// `InitializeSecurityContext`; hopefully the write buffer is empty;
|
||||
|
// must read at least one byte before calling
|
||||
|
// `InitializeSecurityContext` again.
|
||||
|
IncompleteMessage, |
||||
|
// `SEC_E_OK` was returned by `InitializeSecurityContext`; must
|
||||
|
// finish sending data in the write buffer before having `DoHandshake`
|
||||
|
// report success.
|
||||
|
DoneAfterFlush, |
||||
|
// We finished the above and are now connected. At this point, writing
|
||||
|
// and reading are separate 'state machines' represented by the
|
||||
|
// nonemptiness of the ciphertext and cleartext read and write buffers.
|
||||
|
Connected, |
||||
|
// Another error was returned and we shouldn't allow initialization
|
||||
|
// to continue.
|
||||
|
Error, |
||||
|
} handshake_state = HandshakeState::Initial; |
||||
|
|
||||
|
CtxtHandle ctxt; |
||||
|
SecPkgContext_StreamSizes stream_sizes; |
||||
|
|
||||
|
std::shared_ptr<Network::SocketBase> socket; |
||||
|
std::optional<std::string> hostname; |
||||
|
|
||||
|
std::vector<u8> ciphertext_read_buf; |
||||
|
std::vector<u8> ciphertext_write_buf; |
||||
|
std::vector<u8> cleartext_read_buf; |
||||
|
std::vector<u8> cleartext_write_buf; |
||||
|
|
||||
|
bool got_read_eof = false; |
||||
|
size_t read_buf_fill_size = 0; |
||||
|
}; |
||||
|
|
||||
|
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { |
||||
|
auto conn = std::make_unique<SSLConnectionBackendSchannel>(); |
||||
|
const Result res = conn->Init(); |
||||
|
if (res.IsFailure()) { |
||||
|
return res; |
||||
|
} |
||||
|
return conn; |
||||
|
} |
||||
|
|
||||
|
} // namespace Service::SSL
|
||||
@ -0,0 +1,219 @@ |
|||||
|
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
|
||||
|
// SPDX-License-Identifier: GPL-2.0-or-later
|
||||
|
|
||||
|
#include "core/hle/service/ssl/ssl_backend.h"
|
||||
|
#include "core/internal_network/network.h"
|
||||
|
#include "core/internal_network/sockets.h"
|
||||
|
|
||||
|
#include <mutex>
|
||||
|
|
||||
|
#include <Security/SecureTransport.h>
|
||||
|
|
||||
|
// SecureTransport has been deprecated in its entirety in favor of
|
||||
|
// Network.framework, but that does not allow layering TLS on top of an
|
||||
|
// arbitrary socket.
|
||||
|
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
|
|
||||
|
namespace { |
||||
|
|
||||
|
template <typename T> |
||||
|
struct CFReleaser { |
||||
|
T ptr; |
||||
|
|
||||
|
YUZU_NON_COPYABLE(CFReleaser); |
||||
|
constexpr CFReleaser() : ptr(nullptr) {} |
||||
|
constexpr CFReleaser(T ptr) : ptr(ptr) {} |
||||
|
constexpr operator T() { |
||||
|
return ptr; |
||||
|
} |
||||
|
~CFReleaser() { |
||||
|
if (ptr) { |
||||
|
CFRelease(ptr); |
||||
|
} |
||||
|
} |
||||
|
}; |
||||
|
|
||||
|
std::string CFStringToString(CFStringRef cfstr) { |
||||
|
CFReleaser<CFDataRef> cfdata( |
||||
|
CFStringCreateExternalRepresentation(nullptr, cfstr, kCFStringEncodingUTF8, 0)); |
||||
|
ASSERT_OR_EXECUTE(cfdata, { return "???"; }); |
||||
|
return std::string(reinterpret_cast<const char*>(CFDataGetBytePtr(cfdata)), |
||||
|
CFDataGetLength(cfdata)); |
||||
|
} |
||||
|
|
||||
|
std::string OSStatusToString(OSStatus status) { |
||||
|
CFReleaser<CFStringRef> cfstr(SecCopyErrorMessageString(status, nullptr)); |
||||
|
if (!cfstr) { |
||||
|
return "[unknown error]"; |
||||
|
} |
||||
|
return CFStringToString(cfstr); |
||||
|
} |
||||
|
|
||||
|
} // namespace
|
||||
|
|
||||
|
namespace Service::SSL { |
||||
|
|
||||
|
class SSLConnectionBackendSecureTransport final : public SSLConnectionBackend { |
||||
|
public: |
||||
|
Result Init() { |
||||
|
static std::once_flag once_flag; |
||||
|
std::call_once(once_flag, []() { |
||||
|
if (getenv("SSLKEYLOGFILE")) { |
||||
|
LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but SecureTransport does not " |
||||
|
"support exporting keys; not logging keys!"); |
||||
|
// Not fatal.
|
||||
|
} |
||||
|
}); |
||||
|
|
||||
|
context.ptr = SSLCreateContext(nullptr, kSSLClientSide, kSSLStreamType); |
||||
|
if (!context) { |
||||
|
LOG_ERROR(Service_SSL, "SSLCreateContext failed"); |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
|
||||
|
OSStatus status; |
||||
|
if ((status = SSLSetIOFuncs(context, ReadCallback, WriteCallback)) || |
||||
|
(status = SSLSetConnection(context, this))) { |
||||
|
LOG_ERROR(Service_SSL, "SSLContext initialization failed: {}", |
||||
|
OSStatusToString(status)); |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
|
||||
|
return ResultSuccess; |
||||
|
} |
||||
|
|
||||
|
void SetSocket(std::shared_ptr<Network::SocketBase> in_socket) override { |
||||
|
socket = std::move(in_socket); |
||||
|
} |
||||
|
|
||||
|
Result SetHostName(const std::string& hostname) override { |
||||
|
OSStatus status = SSLSetPeerDomainName(context, hostname.c_str(), hostname.size()); |
||||
|
if (status) { |
||||
|
LOG_ERROR(Service_SSL, "SSLSetPeerDomainName failed: {}", OSStatusToString(status)); |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
return ResultSuccess; |
||||
|
} |
||||
|
|
||||
|
Result DoHandshake() override { |
||||
|
OSStatus status = SSLHandshake(context); |
||||
|
return HandleReturn("SSLHandshake", 0, status).Code(); |
||||
|
} |
||||
|
|
||||
|
ResultVal<size_t> Read(std::span<u8> data) override { |
||||
|
size_t actual; |
||||
|
OSStatus status = SSLRead(context, data.data(), data.size(), &actual); |
||||
|
; |
||||
|
return HandleReturn("SSLRead", actual, status); |
||||
|
} |
||||
|
|
||||
|
ResultVal<size_t> Write(std::span<const u8> data) override { |
||||
|
size_t actual; |
||||
|
OSStatus status = SSLWrite(context, data.data(), data.size(), &actual); |
||||
|
; |
||||
|
return HandleReturn("SSLWrite", actual, status); |
||||
|
} |
||||
|
|
||||
|
ResultVal<size_t> HandleReturn(const char* what, size_t actual, OSStatus status) { |
||||
|
switch (status) { |
||||
|
case 0: |
||||
|
return actual; |
||||
|
case errSSLWouldBlock: |
||||
|
return ResultWouldBlock; |
||||
|
default: { |
||||
|
std::string reason; |
||||
|
if (got_read_eof) { |
||||
|
reason = "server hung up"; |
||||
|
} else { |
||||
|
reason = OSStatusToString(status); |
||||
|
} |
||||
|
LOG_ERROR(Service_SSL, "{} failed: {}", what, reason); |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { |
||||
|
CFReleaser<SecTrustRef> trust; |
||||
|
OSStatus status = SSLCopyPeerTrust(context, &trust.ptr); |
||||
|
if (status) { |
||||
|
LOG_ERROR(Service_SSL, "SSLCopyPeerTrust failed: {}", OSStatusToString(status)); |
||||
|
return ResultInternalError; |
||||
|
} |
||||
|
std::vector<std::vector<u8>> ret; |
||||
|
for (CFIndex i = 0, count = SecTrustGetCertificateCount(trust); i < count; i++) { |
||||
|
SecCertificateRef cert = SecTrustGetCertificateAtIndex(trust, i); |
||||
|
CFReleaser<CFDataRef> data(SecCertificateCopyData(cert)); |
||||
|
ASSERT_OR_EXECUTE(data, { return ResultInternalError; }); |
||||
|
const u8* ptr = CFDataGetBytePtr(data); |
||||
|
ret.emplace_back(ptr, ptr + CFDataGetLength(data)); |
||||
|
} |
||||
|
return ret; |
||||
|
} |
||||
|
|
||||
|
static OSStatus ReadCallback(SSLConnectionRef connection, void* data, size_t* dataLength) { |
||||
|
return ReadOrWriteCallback(connection, data, dataLength, true); |
||||
|
} |
||||
|
|
||||
|
static OSStatus WriteCallback(SSLConnectionRef connection, const void* data, |
||||
|
size_t* dataLength) { |
||||
|
return ReadOrWriteCallback(connection, const_cast<void*>(data), dataLength, false); |
||||
|
} |
||||
|
|
||||
|
static OSStatus ReadOrWriteCallback(SSLConnectionRef connection, void* data, size_t* dataLength, |
||||
|
bool is_read) { |
||||
|
auto self = |
||||
|
static_cast<SSLConnectionBackendSecureTransport*>(const_cast<void*>(connection)); |
||||
|
ASSERT_OR_EXECUTE_MSG( |
||||
|
self->socket, { return 0; }, "SecureTransport asked to {} but we have no socket", |
||||
|
is_read ? "read" : "write"); |
||||
|
|
||||
|
// SecureTransport callbacks (unlike OpenSSL BIO callbacks) are
|
||||
|
// expected to read/write the full requested dataLength or return an
|
||||
|
// error, so we have to add a loop ourselves.
|
||||
|
size_t requested_len = *dataLength; |
||||
|
size_t offset = 0; |
||||
|
while (offset < requested_len) { |
||||
|
std::span cur(reinterpret_cast<u8*>(data) + offset, requested_len - offset); |
||||
|
auto [actual, err] = is_read ? self->socket->Recv(0, cur) : self->socket->Send(cur, 0); |
||||
|
LOG_CRITICAL(Service_SSL, "op={}, offset={} actual={}/{} err={}", is_read, offset, |
||||
|
actual, cur.size(), static_cast<s32>(err)); |
||||
|
switch (err) { |
||||
|
case Network::Errno::SUCCESS: |
||||
|
offset += actual; |
||||
|
if (actual == 0) { |
||||
|
ASSERT(is_read); |
||||
|
self->got_read_eof = true; |
||||
|
return errSecEndOfData; |
||||
|
} |
||||
|
break; |
||||
|
case Network::Errno::AGAIN: |
||||
|
*dataLength = offset; |
||||
|
return errSSLWouldBlock; |
||||
|
default: |
||||
|
LOG_ERROR(Service_SSL, "Socket {} returned Network::Errno {}", |
||||
|
is_read ? "recv" : "send", err); |
||||
|
return errSecIO; |
||||
|
} |
||||
|
} |
||||
|
ASSERT(offset == requested_len); |
||||
|
return 0; |
||||
|
} |
||||
|
|
||||
|
private: |
||||
|
CFReleaser<SSLContextRef> context = nullptr; |
||||
|
bool got_read_eof = false; |
||||
|
|
||||
|
std::shared_ptr<Network::SocketBase> socket; |
||||
|
}; |
||||
|
|
||||
|
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { |
||||
|
auto conn = std::make_unique<SSLConnectionBackendSecureTransport>(); |
||||
|
const Result res = conn->Init(); |
||||
|
if (res.IsFailure()) { |
||||
|
return res; |
||||
|
} |
||||
|
return conn; |
||||
|
} |
||||
|
|
||||
|
} // namespace Service::SSL
|
||||
Write
Preview
Loading…
Cancel
Save
Reference in new issue