1 changed files with 219 additions and 0 deletions
@ -0,0 +1,219 @@ |
|||
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
|
|||
// SPDX-License-Identifier: GPL-2.0-or-later
|
|||
|
|||
#include "core/hle/service/ssl/ssl_backend.h"
|
|||
#include "core/internal_network/network.h"
|
|||
#include "core/internal_network/sockets.h"
|
|||
|
|||
#include <mutex>
|
|||
|
|||
#include <Security/SecureTransport.h>
|
|||
|
|||
// SecureTransport has been deprecated in its entirety in favor of
|
|||
// Network.framework, but that does not allow layering TLS on top of an
|
|||
// arbitrary socket.
|
|||
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
|||
|
|||
namespace { |
|||
|
|||
template <typename T> |
|||
struct CFReleaser { |
|||
T ptr; |
|||
|
|||
YUZU_NON_COPYABLE(CFReleaser); |
|||
constexpr CFReleaser() : ptr(nullptr) {} |
|||
constexpr CFReleaser(T ptr) : ptr(ptr) {} |
|||
constexpr operator T() { |
|||
return ptr; |
|||
} |
|||
~CFReleaser() { |
|||
if (ptr) { |
|||
CFRelease(ptr); |
|||
} |
|||
} |
|||
}; |
|||
|
|||
std::string CFStringToString(CFStringRef cfstr) { |
|||
CFReleaser<CFDataRef> cfdata( |
|||
CFStringCreateExternalRepresentation(nullptr, cfstr, kCFStringEncodingUTF8, 0)); |
|||
ASSERT_OR_EXECUTE(cfdata, { return "???"; }); |
|||
return std::string(reinterpret_cast<const char*>(CFDataGetBytePtr(cfdata)), |
|||
CFDataGetLength(cfdata)); |
|||
} |
|||
|
|||
std::string OSStatusToString(OSStatus status) { |
|||
CFReleaser<CFStringRef> cfstr(SecCopyErrorMessageString(status, nullptr)); |
|||
if (!cfstr) { |
|||
return "[unknown error]"; |
|||
} |
|||
return CFStringToString(cfstr); |
|||
} |
|||
|
|||
} // namespace
|
|||
|
|||
namespace Service::SSL { |
|||
|
|||
class SSLConnectionBackendSecureTransport final : public SSLConnectionBackend { |
|||
public: |
|||
Result Init() { |
|||
static std::once_flag once_flag; |
|||
std::call_once(once_flag, []() { |
|||
if (getenv("SSLKEYLOGFILE")) { |
|||
LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but SecureTransport does not " |
|||
"support exporting keys; not logging keys!"); |
|||
// Not fatal.
|
|||
} |
|||
}); |
|||
|
|||
context.ptr = SSLCreateContext(nullptr, kSSLClientSide, kSSLStreamType); |
|||
if (!context) { |
|||
LOG_ERROR(Service_SSL, "SSLCreateContext failed"); |
|||
return ResultInternalError; |
|||
} |
|||
|
|||
OSStatus status; |
|||
if ((status = SSLSetIOFuncs(context, ReadCallback, WriteCallback)) || |
|||
(status = SSLSetConnection(context, this))) { |
|||
LOG_ERROR(Service_SSL, "SSLContext initialization failed: {}", |
|||
OSStatusToString(status)); |
|||
return ResultInternalError; |
|||
} |
|||
|
|||
return ResultSuccess; |
|||
} |
|||
|
|||
void SetSocket(std::shared_ptr<Network::SocketBase> in_socket) override { |
|||
socket = std::move(in_socket); |
|||
} |
|||
|
|||
Result SetHostName(const std::string& hostname) override { |
|||
OSStatus status = SSLSetPeerDomainName(context, hostname.c_str(), hostname.size()); |
|||
if (status) { |
|||
LOG_ERROR(Service_SSL, "SSLSetPeerDomainName failed: {}", OSStatusToString(status)); |
|||
return ResultInternalError; |
|||
} |
|||
return ResultSuccess; |
|||
} |
|||
|
|||
Result DoHandshake() override { |
|||
OSStatus status = SSLHandshake(context); |
|||
return HandleReturn("SSLHandshake", 0, status).Code(); |
|||
} |
|||
|
|||
ResultVal<size_t> Read(std::span<u8> data) override { |
|||
size_t actual; |
|||
OSStatus status = SSLRead(context, data.data(), data.size(), &actual); |
|||
; |
|||
return HandleReturn("SSLRead", actual, status); |
|||
} |
|||
|
|||
ResultVal<size_t> Write(std::span<const u8> data) override { |
|||
size_t actual; |
|||
OSStatus status = SSLWrite(context, data.data(), data.size(), &actual); |
|||
; |
|||
return HandleReturn("SSLWrite", actual, status); |
|||
} |
|||
|
|||
ResultVal<size_t> HandleReturn(const char* what, size_t actual, OSStatus status) { |
|||
switch (status) { |
|||
case 0: |
|||
return actual; |
|||
case errSSLWouldBlock: |
|||
return ResultWouldBlock; |
|||
default: { |
|||
std::string reason; |
|||
if (got_read_eof) { |
|||
reason = "server hung up"; |
|||
} else { |
|||
reason = OSStatusToString(status); |
|||
} |
|||
LOG_ERROR(Service_SSL, "{} failed: {}", what, reason); |
|||
return ResultInternalError; |
|||
} |
|||
} |
|||
} |
|||
|
|||
ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { |
|||
CFReleaser<SecTrustRef> trust; |
|||
OSStatus status = SSLCopyPeerTrust(context, &trust.ptr); |
|||
if (status) { |
|||
LOG_ERROR(Service_SSL, "SSLCopyPeerTrust failed: {}", OSStatusToString(status)); |
|||
return ResultInternalError; |
|||
} |
|||
std::vector<std::vector<u8>> ret; |
|||
for (CFIndex i = 0, count = SecTrustGetCertificateCount(trust); i < count; i++) { |
|||
SecCertificateRef cert = SecTrustGetCertificateAtIndex(trust, i); |
|||
CFReleaser<CFDataRef> data(SecCertificateCopyData(cert)); |
|||
ASSERT_OR_EXECUTE(data, { return ResultInternalError; }); |
|||
const u8* ptr = CFDataGetBytePtr(data); |
|||
ret.emplace_back(ptr, ptr + CFDataGetLength(data)); |
|||
} |
|||
return ret; |
|||
} |
|||
|
|||
static OSStatus ReadCallback(SSLConnectionRef connection, void* data, size_t* dataLength) { |
|||
return ReadOrWriteCallback(connection, data, dataLength, true); |
|||
} |
|||
|
|||
static OSStatus WriteCallback(SSLConnectionRef connection, const void* data, |
|||
size_t* dataLength) { |
|||
return ReadOrWriteCallback(connection, const_cast<void*>(data), dataLength, false); |
|||
} |
|||
|
|||
static OSStatus ReadOrWriteCallback(SSLConnectionRef connection, void* data, size_t* dataLength, |
|||
bool is_read) { |
|||
auto self = |
|||
static_cast<SSLConnectionBackendSecureTransport*>(const_cast<void*>(connection)); |
|||
ASSERT_OR_EXECUTE_MSG( |
|||
self->socket, { return 0; }, "SecureTransport asked to {} but we have no socket", |
|||
is_read ? "read" : "write"); |
|||
|
|||
// SecureTransport callbacks (unlike OpenSSL BIO callbacks) are
|
|||
// expected to read/write the full requested dataLength or return an
|
|||
// error, so we have to add a loop ourselves.
|
|||
size_t requested_len = *dataLength; |
|||
size_t offset = 0; |
|||
while (offset < requested_len) { |
|||
std::span cur(reinterpret_cast<u8*>(data) + offset, requested_len - offset); |
|||
auto [actual, err] = is_read ? self->socket->Recv(0, cur) : self->socket->Send(cur, 0); |
|||
LOG_CRITICAL(Service_SSL, "op={}, offset={} actual={}/{} err={}", is_read, offset, |
|||
actual, cur.size(), static_cast<s32>(err)); |
|||
switch (err) { |
|||
case Network::Errno::SUCCESS: |
|||
offset += actual; |
|||
if (actual == 0) { |
|||
ASSERT(is_read); |
|||
self->got_read_eof = true; |
|||
return errSecEndOfData; |
|||
} |
|||
break; |
|||
case Network::Errno::AGAIN: |
|||
*dataLength = offset; |
|||
return errSSLWouldBlock; |
|||
default: |
|||
LOG_ERROR(Service_SSL, "Socket {} returned Network::Errno {}", |
|||
is_read ? "recv" : "send", err); |
|||
return errSecIO; |
|||
} |
|||
} |
|||
ASSERT(offset == requested_len); |
|||
return 0; |
|||
} |
|||
|
|||
private: |
|||
CFReleaser<SSLContextRef> context = nullptr; |
|||
bool got_read_eof = false; |
|||
|
|||
std::shared_ptr<Network::SocketBase> socket; |
|||
}; |
|||
|
|||
ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { |
|||
auto conn = std::make_unique<SSLConnectionBackendSecureTransport>(); |
|||
const Result res = conn->Init(); |
|||
if (res.IsFailure()) { |
|||
return res; |
|||
} |
|||
return conn; |
|||
} |
|||
|
|||
} // namespace Service::SSL
|
|||
Write
Preview
Loading…
Cancel
Save
Reference in new issue