Browse Source

speedy crypto?

pull/2746/head
Ribbit 5 months ago
parent
commit
235be9318e
  1. 54
      src/core/crypto/aes_util.cpp
  2. 50
      src/core/crypto/ctr_encryption_layer.cpp
  3. 100
      src/core/crypto/xts_encryption_layer.cpp
  4. 8
      src/core/file_sys/fssystem/fssystem_aes_ctr_storage.cpp
  5. 11
      src/core/file_sys/fssystem/fssystem_aes_xts_storage.cpp

54
src/core/crypto/aes_util.cpp

@ -2,6 +2,7 @@
// SPDX-License-Identifier: GPL-2.0-or-later // SPDX-License-Identifier: GPL-2.0-or-later
#include <array> #include <array>
#include <vector>
#include <mbedtls/cipher.h> #include <mbedtls/cipher.h>
#include "common/assert.h" #include "common/assert.h"
#include "common/logging/log.h" #include "common/logging/log.h"
@ -71,37 +72,42 @@ void AESCipher<Key, KeySize>::Transcode(const u8* src, std::size_t size, u8* des
mbedtls_cipher_reset(context); mbedtls_cipher_reset(context);
// Only ECB strictly requires block sized chunks.
const auto mode = mbedtls_cipher_get_cipher_mode(context);
std::size_t written = 0; std::size_t written = 0;
if (mbedtls_cipher_get_cipher_mode(context) == MBEDTLS_MODE_XTS) {
if (mode != MBEDTLS_MODE_ECB) {
mbedtls_cipher_update(context, src, size, dest, &written); mbedtls_cipher_update(context, src, size, dest, &written);
if (written != size) { if (written != size) {
LOG_WARNING(Crypto, "Not all data was decrypted requested={:016X}, actual={:016X}.",
LOG_WARNING(Crypto, "Not all data was processed requested={:016X}, actual={:016X}.",
size, written); size, written);
} }
} else {
const auto block_size = mbedtls_cipher_get_block_size(context);
if (size < block_size) {
std::vector<u8> block(block_size);
std::memcpy(block.data(), src, size);
Transcode(block.data(), block.size(), block.data(), op);
std::memcpy(dest, block.data(), size);
return;
}
return;
}
// ECB path: operate in block sized chunks and mirror previous behavior.
const auto block_size = mbedtls_cipher_get_block_size(context);
if (size < block_size) {
std::vector<u8> block(block_size);
std::memcpy(block.data(), src, size);
Transcode(block.data(), block.size(), block.data(), op);
std::memcpy(dest, block.data(), size);
return;
}
for (std::size_t offset = 0; offset < size; offset += block_size) {
auto length = std::min<std::size_t>(block_size, size - offset);
mbedtls_cipher_update(context, src + offset, length, dest + offset, &written);
if (written != length) {
if (length < block_size) {
std::vector<u8> block(block_size);
std::memcpy(block.data(), src + offset, length);
Transcode(block.data(), block.size(), block.data(), op);
std::memcpy(dest + offset, block.data(), length);
return;
}
LOG_WARNING(Crypto, "Not all data was decrypted requested={:016X}, actual={:016X}.",
length, written);
for (std::size_t offset = 0; offset < size; offset += block_size) {
const auto length = std::min<std::size_t>(block_size, size - offset);
mbedtls_cipher_update(context, src + offset, length, dest + offset, &written);
if (written != length) {
if (length < block_size) {
std::vector<u8> block(block_size);
std::memcpy(block.data(), src + offset, length);
Transcode(block.data(), block.size(), block.data(), op);
std::memcpy(dest + offset, block.data(), length);
return;
} }
LOG_WARNING(Crypto, "Not all data was processed requested={:016X}, actual={:016X}.",
length, written);
} }
} }
} }

50
src/core/crypto/ctr_encryption_layer.cpp

@ -12,29 +12,47 @@ CTREncryptionLayer::CTREncryptionLayer(FileSys::VirtualFile base_, Key128 key_,
: EncryptionLayer(std::move(base_)), base_offset(base_offset_), cipher(key_, Mode::CTR) {} : EncryptionLayer(std::move(base_)), base_offset(base_offset_), cipher(key_, Mode::CTR) {}
std::size_t CTREncryptionLayer::Read(u8* data, std::size_t length, std::size_t offset) const { std::size_t CTREncryptionLayer::Read(u8* data, std::size_t length, std::size_t offset) const {
if (length == 0)
if (length == 0) {
return 0; return 0;
}
std::size_t total_read = 0;
// Handle an initial misaligned portion if needed.
const auto sector_offset = offset & 0xF; const auto sector_offset = offset & 0xF;
if (sector_offset == 0) {
UpdateIV(base_offset + offset);
std::vector<u8> raw = base->ReadBytes(length, offset);
cipher.Transcode(raw.data(), raw.size(), data, Op::Decrypt);
return length;
if (sector_offset != 0) {
const std::size_t aligned_off = offset - sector_offset;
std::array<u8, 0x10> block{};
const std::size_t got = base->Read(block.data(), block.size(), aligned_off);
if (got == 0) {
return 0;
}
UpdateIV(base_offset + aligned_off);
cipher.Transcode(block.data(), got, block.data(), Op::Decrypt);
const std::size_t to_copy = std::min<std::size_t>(length, got > sector_offset ? got - sector_offset : 0);
if (to_copy > 0) {
std::memcpy(data, block.data() + sector_offset, to_copy);
data += to_copy;
offset += to_copy;
length -= to_copy;
total_read += to_copy;
}
} }
// offset does not fall on block boundary (0x10)
std::vector<u8> block = base->ReadBytes(0x10, offset - sector_offset);
UpdateIV(base_offset + offset - sector_offset);
cipher.Transcode(block.data(), block.size(), block.data(), Op::Decrypt);
std::size_t read = 0x10 - sector_offset;
if (length == 0) {
return total_read;
}
if (length + sector_offset < 0x10) {
std::memcpy(data, block.data() + sector_offset, std::min<u64>(length, read));
return std::min<u64>(length, read);
// Now aligned to 0x10
UpdateIV(base_offset + offset);
const std::size_t got = base->Read(data, length, offset);
if (got > 0) {
cipher.Transcode(data, got, data, Op::Decrypt);
total_read += got;
} }
std::memcpy(data, block.data() + sector_offset, read);
return read + Read(data + read, length - read, offset + read);
return total_read;
} }
void CTREncryptionLayer::SetIV(const IVData& iv_) { void CTREncryptionLayer::SetIV(const IVData& iv_) {

100
src/core/crypto/xts_encryption_layer.cpp

@ -16,44 +16,80 @@ XTSEncryptionLayer::XTSEncryptionLayer(FileSys::VirtualFile base_, Key256 key_)
: EncryptionLayer(std::move(base_)), cipher(key_, Mode::XTS) {} : EncryptionLayer(std::move(base_)), cipher(key_, Mode::XTS) {}
std::size_t XTSEncryptionLayer::Read(u8* data, std::size_t length, std::size_t offset) const { std::size_t XTSEncryptionLayer::Read(u8* data, std::size_t length, std::size_t offset) const {
if (length == 0)
if (length == 0) {
return 0; return 0;
}
std::size_t total_read = 0;
const std::size_t sector_size = XTS_SECTOR_SIZE;
const std::size_t sector_offset = offset % sector_size;
const auto sector_offset = offset & 0x3FFF;
if (sector_offset == 0) {
if (length % XTS_SECTOR_SIZE == 0) {
std::vector<u8> raw = base->ReadBytes(length, offset);
cipher.XTSTranscode(raw.data(), raw.size(), data, offset / XTS_SECTOR_SIZE,
XTS_SECTOR_SIZE, Op::Decrypt);
return raw.size();
// Handle initial unaligned part within a sector.
if (sector_offset != 0) {
const std::size_t aligned_off = offset - sector_offset;
std::array<u8, XTS_SECTOR_SIZE> block{};
std::size_t got = base->Read(block.data(), sector_size, aligned_off);
if (got == 0) {
return 0;
} }
if (length > XTS_SECTOR_SIZE) {
const auto rem = length % XTS_SECTOR_SIZE;
const auto read = length - rem;
return Read(data, read, offset) + Read(data + read, rem, offset + read);
if (got < sector_size) {
std::fill(block.begin() + got, block.end(), 0);
}
cipher.XTSTranscode(block.data(), sector_size, block.data(), aligned_off / sector_size,
sector_size, Op::Decrypt);
const std::size_t to_copy = std::min<std::size_t>(length, got > sector_offset ? got - sector_offset : 0);
if (to_copy > 0) {
std::memcpy(data, block.data() + sector_offset, to_copy);
data += to_copy;
offset += to_copy;
length -= to_copy;
total_read += to_copy;
} }
std::vector<u8> buffer = base->ReadBytes(XTS_SECTOR_SIZE, offset);
if (buffer.size() < XTS_SECTOR_SIZE)
buffer.resize(XTS_SECTOR_SIZE);
cipher.XTSTranscode(buffer.data(), buffer.size(), buffer.data(), offset / XTS_SECTOR_SIZE,
XTS_SECTOR_SIZE, Op::Decrypt);
std::memcpy(data, buffer.data(), (std::min)(buffer.size(), length));
return (std::min)(buffer.size(), length);
} }
// offset does not fall on block boundary (0x4000)
std::vector<u8> block = base->ReadBytes(0x4000, offset - sector_offset);
if (block.size() < XTS_SECTOR_SIZE)
block.resize(XTS_SECTOR_SIZE);
cipher.XTSTranscode(block.data(), block.size(), block.data(),
(offset - sector_offset) / XTS_SECTOR_SIZE, XTS_SECTOR_SIZE, Op::Decrypt);
const std::size_t read = XTS_SECTOR_SIZE - sector_offset;
if (length + sector_offset < XTS_SECTOR_SIZE) {
std::memcpy(data, block.data() + sector_offset, std::min<u64>(length, read));
return std::min<u64>(length, read);
if (length == 0) {
return total_read;
} }
std::memcpy(data, block.data() + sector_offset, read);
return read + Read(data + read, length - read, offset + read);
// Process aligned middle inplace, in sector sized multiples.
while (length >= sector_size) {
const std::size_t req = (length / sector_size) * sector_size;
const std::size_t got = base->Read(data, req, offset);
if (got == 0) {
return total_read;
}
const std::size_t got_rounded = got - (got % sector_size);
if (got_rounded > 0) {
cipher.XTSTranscode(data, got_rounded, data, offset / sector_size, sector_size,
Op::Decrypt);
data += got_rounded;
offset += got_rounded;
length -= got_rounded;
total_read += got_rounded;
}
// If we didn't get a full sector next, break to handle tail.
if (got_rounded != got) {
break;
}
}
// Handle tail within a sector, if any.
if (length > 0) {
std::array<u8, XTS_SECTOR_SIZE> block{};
const std::size_t got = base->Read(block.data(), sector_size, offset);
if (got > 0) {
if (got < sector_size) {
std::fill(block.begin() + got, block.end(), 0);
}
cipher.XTSTranscode(block.data(), sector_size, block.data(), offset / sec
tor_size, sector_size, Op::Decrypt);
const std::size_t to_copy = std::min<std::size_t>(length, got);
std::memcpy(data, block.data(), to_copy);
total_read += to_copy;
}
}
return total_read;
} }
} // namespace Core::Crypto } // namespace Core::Crypto

8
src/core/file_sys/fssystem/fssystem_aes_ctr_storage.cpp

@ -5,6 +5,7 @@
// SPDX-License-Identifier: GPL-2.0-or-later // SPDX-License-Identifier: GPL-2.0-or-later
#include "common/alignment.h" #include "common/alignment.h"
#include <vector>
#include "common/swap.h" #include "common/swap.h"
#include "core/file_sys/fssystem/fssystem_aes_ctr_storage.h" #include "core/file_sys/fssystem/fssystem_aes_ctr_storage.h"
#include "core/file_sys/fssystem/fssystem_utility.h" #include "core/file_sys/fssystem/fssystem_utility.h"
@ -88,11 +89,14 @@ size_t AesCtrStorage::Write(const u8* buffer, size_t size, size_t offset) {
s64 cur_offset = 0; s64 cur_offset = 0;
// Get a pooled buffer. // Get a pooled buffer.
std::vector<char> pooled_buffer(BlockSize);
thread_local std::vector<u8> pooled_buffer;
if (pooled_buffer.size() < BlockSize) {
pooled_buffer.resize(BlockSize);
}
while (remaining > 0) { while (remaining > 0) {
// Determine data we're writing and where. // Determine data we're writing and where.
const size_t write_size = std::min(pooled_buffer.size(), remaining); const size_t write_size = std::min(pooled_buffer.size(), remaining);
u8* write_buf = reinterpret_cast<u8*>(pooled_buffer.data());
u8* write_buf = pooled_buffer.data();
// Encrypt the data. // Encrypt the data.
m_cipher->SetIV(ctr); m_cipher->SetIV(ctr);

11
src/core/file_sys/fssystem/fssystem_aes_xts_storage.cpp

@ -4,6 +4,9 @@
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project // SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later // SPDX-License-Identifier: GPL-2.0-or-later
#include <algorithm>
#include <vector>
#include "common/alignment.h" #include "common/alignment.h"
#include "common/swap.h" #include "common/swap.h"
#include "core/file_sys/fssystem/fssystem_aes_xts_storage.h" #include "core/file_sys/fssystem/fssystem_aes_xts_storage.h"
@ -68,9 +71,13 @@ size_t AesXtsStorage::Read(u8* buffer, size_t size, size_t offset) const {
static_cast<size_t>(offset - Common::AlignDown(offset, m_block_size)); static_cast<size_t>(offset - Common::AlignDown(offset, m_block_size));
const size_t data_size = (std::min)(size, m_block_size - skip_size); const size_t data_size = (std::min)(size, m_block_size - skip_size);
// Decrypt into a pooled buffer.
// Decrypt into a thread-local pooled buffer to avoid per-call allocations.
{ {
std::vector<char> tmp_buf(m_block_size, 0);
thread_local std::vector<u8> tmp_buf;
if (tmp_buf.size() < m_block_size) {
tmp_buf.resize(m_block_size);
}
std::fill(tmp_buf.begin(), tmp_buf.begin() + m_block_size, 0);
std::memcpy(tmp_buf.data() + skip_size, buffer, data_size); std::memcpy(tmp_buf.data() + skip_size, buffer, data_size);
m_cipher->SetIV(ctr); m_cipher->SetIV(ctr);

Loading…
Cancel
Save