Kernel: Add safe_memcpy, safe_memset and safe_strnlen

These special functions can be used to safely copy/set memory or
determine the length of a string, e.g. provided by user mode.

In the event of a page fault, safe_memcpy/safe_memset will return
false and safe_strnlen will return -1.
This commit is contained in:
Tom 2020-09-10 19:49:31 -06:00 committed by Andreas Kling
parent 1c86ab0108
commit 7d1b8417bd
Notes: sideshowbarker 2024-07-19 02:42:40 +09:00
2 changed files with 180 additions and 3 deletions

View file

@ -150,10 +150,11 @@ static void dump(const RegisterState& regs)
klog() << "cr0=" << String::format("%08x", cr0) << " cr2=" << String::format("%08x", cr2) << " cr3=" << String::format("%08x", cr3) << " cr4=" << String::format("%08x", cr4);
auto process = Process::current();
if (process && process->validate_read((void*)regs.eip, 8)) {
u8 code[8];
void* fault_at;
if (process && safe_memcpy(code, (void*)regs.eip, 8, fault_at)) {
SmapDisabler disabler;
u8* codeptr = (u8*)regs.eip;
klog() << "code: " << String::format("%02x", codeptr[0]) << " " << String::format("%02x", codeptr[1]) << " " << String::format("%02x", codeptr[2]) << " " << String::format("%02x", codeptr[3]) << " " << String::format("%02x", codeptr[4]) << " " << String::format("%02x", codeptr[5]) << " " << String::format("%02x", codeptr[6]) << " " << String::format("%02x", codeptr[7]);
klog() << "code: " << String::format("%02x", code[0]) << " " << String::format("%02x", code[1]) << " " << String::format("%02x", code[2]) << " " << String::format("%02x", code[3]) << " " << String::format("%02x", code[4]) << " " << String::format("%02x", code[5]) << " " << String::format("%02x", code[6]) << " " << String::format("%02x", code[7]);
}
}
@ -212,6 +213,171 @@ void fpu_exception_handler(TrapFrame*)
asm volatile("clts");
}
extern "C" u8* safe_memcpy_ins_1;
extern "C" u8* safe_memcpy_1_faulted;
extern "C" u8* safe_memcpy_ins_2;
extern "C" u8* safe_memcpy_2_faulted;
extern "C" u8* safe_strnlen_ins;
extern "C" u8* safe_strnlen_faulted;
extern "C" u8* safe_memset_ins_1;
extern "C" u8* safe_memset_1_faulted;
extern "C" u8* safe_memset_ins_2;
extern "C" u8* safe_memset_2_faulted;
bool safe_memcpy(void* dest_ptr, const void* src_ptr, size_t n, void*& fault_at)
{
size_t dest = (size_t)dest_ptr;
size_t src = (size_t)src_ptr;
size_t remainder;
// FIXME: Support starting at an unaligned address.
if (!(dest & 0x3) && !(src & 0x3) && n >= 12) {
size_t size_ts = n / sizeof(size_t);
asm volatile(
"xor %[fault_at], %[fault_at] \n"
".global safe_memcpy_ins_1 \n"
"safe_memcpy_ins_1: \n"
"rep movsl \n"
".global safe_memcpy_1_faulted \n"
"safe_memcpy_1_faulted: \n" // handle_safe_access_fault() set edx to the fault address!
: "=S" (src),
"=D" (dest),
"=c" (remainder),
[fault_at] "=d" (fault_at)
: "S" (src),
"D" (dest),
"c" (size_ts)
: "memory");
if (remainder != 0)
return false; // fault_at is already set!
n -= size_ts * sizeof(size_t);
if (n == 0) {
fault_at = nullptr;
return true;
}
}
asm volatile(
"xor %[fault_at], %[fault_at] \n"
".global safe_memcpy_ins_2 \n"
"safe_memcpy_ins_2: \n"
"rep movsb \n"
".global safe_memcpy_2_faulted \n"
"safe_memcpy_2_faulted: \n" // handle_safe_access_fault() set edx to the fault address!
: "=c" (remainder),
[fault_at] "=d" (fault_at)
: "S" (src),
"D" (dest),
"c" (n)
: "memory");
if (remainder != 0)
return false; // fault_at is already set!
fault_at = nullptr;
return true;
}
ssize_t safe_strnlen(const char* str, size_t max_n, void*& fault_at)
{
ssize_t count = 0;
asm volatile(
"xor %[fault_at], %[fault_at] \n"
"1: \n"
"test %[max_n], %[max_n] \n"
"je 2f \n"
"dec %[max_n] \n"
".global safe_strnlen_ins \n"
"safe_strnlen_ins: \n"
"cmpb $0,(%[str], %[count], 1) \n"
"je 2f \n"
"inc %[count] \n"
"jmp 1b \n"
".global safe_strnlen_faulted \n"
"safe_strnlen_faulted: \n" // handle_safe_access_fault() set edx to the fault address!
"xor %[count_on_error], %[count_on_error] \n"
"dec %[count_on_error] \n" // return -1 on fault
"2:"
: [count_on_error] "=c" (count),
[fault_at] "=d" (fault_at)
: [str] "b" (str),
[count] "c" (count),
[max_n] "d" (max_n)
);
if (count >= 0)
fault_at = nullptr;
return count;
}
bool safe_memset(void* dest_ptr, int c, size_t n, void*& fault_at)
{
size_t dest = (size_t)dest_ptr;
size_t remainder;
// FIXME: Support starting at an unaligned address.
if (!(dest & 0x3) && n >= 12) {
size_t size_ts = n / sizeof(size_t);
size_t expanded_c = (u8)c;
expanded_c |= expanded_c << 8;
expanded_c |= expanded_c << 16;
asm volatile(
"xor %[fault_at], %[fault_at] \n"
".global safe_memset_ins_1 \n"
"safe_memset_ins_1: \n"
"rep stosl \n"
".global safe_memset_1_faulted \n"
"safe_memset_1_faulted: \n" // handle_safe_access_fault() set edx to the fault address!
: "=D" (dest),
"=c" (remainder),
[fault_at] "=d" (fault_at)
: "D" (dest),
"a" (expanded_c),
"c" (size_ts)
: "memory");
if (remainder != 0)
return false; // fault_at is already set!
n -= size_ts * sizeof(size_t);
if (remainder == 0) {
fault_at = nullptr;
return true;
}
}
asm volatile(
"xor %[fault_at], %[fault_at] \n"
".global safe_memset_ins_2 \n"
"safe_memset_ins_2: \n"
"rep stosb \n"
".global safe_memset_2_faulted \n"
"safe_memset_2_faulted: \n" // handle_safe_access_fault() set edx to the fault address!
: "=D" (dest),
"=c" (remainder),
[fault_at] "=d" (fault_at)
: "D" (dest),
"c" (n),
"a" (c)
: "memory");
if (remainder != 0)
return false; // fault_at is already set!
fault_at = nullptr;
return true;
}
static bool handle_safe_access_fault(RegisterState& regs, u32 fault_address)
{
// If we detect that the fault happened in safe_memcpy() safe_strnlen(),
// or safe_memset() then resume at the appropriate _faulted label
if (regs.eip == (FlatPtr)&safe_memcpy_ins_1)
regs.eip = (FlatPtr)&safe_memcpy_1_faulted;
else if (regs.eip == (FlatPtr)&safe_memcpy_ins_2)
regs.eip = (FlatPtr)&safe_memcpy_2_faulted;
else if (regs.eip == (FlatPtr)&safe_strnlen_ins)
regs.eip = (FlatPtr)&safe_strnlen_faulted;
else if (regs.eip == (FlatPtr)&safe_memset_ins_1)
regs.eip = (FlatPtr)&safe_memset_1_faulted;
else if (regs.eip == (FlatPtr)&safe_memset_ins_2)
regs.eip = (FlatPtr)&safe_memset_2_faulted;
else
return false;
regs.edx = fault_address;
return true;
}
// 14: Page Fault
EH_ENTRY(14, page_fault);
void page_fault_handler(TrapFrame* trap)
@ -248,6 +414,13 @@ void page_fault_handler(TrapFrame* trap)
auto response = MM.handle_page_fault(PageFault(regs.exception_code, VirtualAddress(fault_address)));
if (response == PageFaultResponse::ShouldCrash || response == PageFaultResponse::OutOfMemory) {
if (!(regs.cs & 3) && handle_safe_access_fault(regs, fault_address)) {
// If this would be a ring0 (kernel) fault and the fault was triggered by
// safe_memcpy, safe_strnlen, or safe_memset then we resume execution at
// the appropriate _fault label rather than crashing
return;
}
if (response != PageFaultResponse::OutOfMemory) {
if (current_thread->has_signal_handler(SIGSEGV)) {
current_thread->send_urgent_signal_to_self(SIGSEGV);

View file

@ -279,6 +279,10 @@ void flush_idt();
void load_task_register(u16 selector);
void handle_crash(RegisterState&, const char* description, int signal, bool out_of_memory = false);
[[nodiscard]] bool safe_memcpy(void* dest_ptr, const void* src_ptr, size_t n, void*& fault_at);
[[nodiscard]] ssize_t safe_strnlen(const char* str, size_t max_n, void*& fault_at);
[[nodiscard]] bool safe_memset(void* dest_ptr, int c, size_t n, void*& fault_at);
#define LSW(x) ((u32)(x)&0xFFFF)
#define MSW(x) (((u32)(x) >> 16) & 0xFFFF)
#define LSB(x) ((x)&0xFF)