/*
 * exception_test3.c
 *
 * Test for behavior of stack modification in mach exception handlers.
 *
 * Cyrus Harmon and Alastair Bridgewater, Dec 2006.
 */

#include <mach/mach.h>
#include <mach/mach_error.h>
#include <mach/mach_types.h>
#include <mach/sync_policy.h>
#include <mach/machine/thread_state.h>
#include <mach/machine/thread_status.h>
#include <sys/_types.h>
#include <sys/ucontext.h>
#include <pthread.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>


/* One machine word, used for stack slots and values pushed onto the
 * faulting thread's stack.  Assumes an ILP32 target (i386), where
 * unsigned long is 32 bits wide. */
typedef unsigned long u32;

/* Runs in the faulting thread as part of the signal emulation.
 * Populates the caller-supplied ucontext from a saved i386 thread
 * state: the current signal mask is stored in uc_sigmask, and every
 * register is copied into the save area that uc_mcontext must already
 * point at. */
void build_fake_signal_context(struct ucontext *context,
                               i386_thread_state_t *thread_state) {
    /* The new-set argument is NULL, so this only reads the current
     * mask into the context. */
    pthread_sigmask(0, NULL, &context->uc_sigmask);

#define SAVE_REG(reg) (context->uc_mcontext->ss.reg = thread_state->reg)
    /* General-purpose registers. */
    SAVE_REG(eax);
    SAVE_REG(ebx);
    SAVE_REG(ecx);
    SAVE_REG(edx);
    SAVE_REG(edi);
    SAVE_REG(esi);
    /* Frame, stack and control registers. */
    SAVE_REG(ebp);
    SAVE_REG(esp);
    SAVE_REG(ss);
    SAVE_REG(eflags);
    SAVE_REG(eip);
    SAVE_REG(cs);
    /* Remaining segment registers. */
    SAVE_REG(ds);
    SAVE_REG(es);
    SAVE_REG(fs);
    SAVE_REG(gs);
#undef SAVE_REG
}

/* Runs in the faulting thread as part of the signal emulation and is
 * the inverse of build_fake_signal_context.  Every register is copied
 * from the ucontext (which the handler may have modified) back into
 * the thread state.  Then the signal mask recorded in the context is
 * reinstalled. */
void update_thread_state_from_context(i386_thread_state_t *thread_state,
                                      struct ucontext *context) {
#define LOAD_REG(reg) (thread_state->reg = context->uc_mcontext->ss.reg)
    /* General-purpose registers. */
    LOAD_REG(eax);
    LOAD_REG(ebx);
    LOAD_REG(ecx);
    LOAD_REG(edx);
    LOAD_REG(edi);
    LOAD_REG(esi);
    /* Frame, stack and control registers. */
    LOAD_REG(ebp);
    LOAD_REG(esp);
    LOAD_REG(ss);
    LOAD_REG(eflags);
    LOAD_REG(eip);
    LOAD_REG(cs);
    /* Remaining segment registers. */
    LOAD_REG(ds);
    LOAD_REG(es);
    LOAD_REG(fs);
    LOAD_REG(gs);
#undef LOAD_REG

    pthread_sigmask(SIG_SETMASK, &context->uc_sigmask, NULL);
}

/* Stand-in signal handler for the test.  The first two invocations
 * print a message; the third ends the process with exit status 0. */
void
bogus_handler(int signal, siginfo_t *siginfo, void *void_context) {
    static int calls_left = 3;

    calls_left -= 1;
    if (calls_left == 0) {
        exit(0);
    }

    printf("dude, we're here!\n");
}


/* Push DATA onto the stack described by CONTEXT.  ESP is lowered by
 * one u32 and DATA is written at the new top of stack. */
void push_context(u32 data, i386_thread_state_t *context)
{
    u32 *top = (u32 *) context->esp - 1;

    *top = data;
    context->esp = (unsigned int) top;
}

/* Push zero words onto CONTEXT's stack until ESP is a multiple of 16.
 *
 * ESP must already be 4-byte aligned, as it should be.  Pushing whole
 * words never changes the low two bits of ESP, so a misaligned ESP
 * would make this loop forever while writing zeros down the stack.
 * The assert turns that into an immediate, diagnosable failure. */
void align_context_stack(i386_thread_state_t *context)
{
    assert((context->esp & 3) == 0);

    while (context->esp & 15) {
        push_context(0, context);
    }
}

/* Stack allocation starts with a context that has a mod-4 ESP value
 * and needs to leave a context with a mod-16 ESP that will restore
 * the old ESP value and other register state when activated.  The
 * first part of this is the recovery trampoline, which loads ESP from
 * EBP, pops EBP, and returns.
 *
 * open_stack_allocation pushes the old EIP and then the old EBP, and
 * points EBP at that saved EBP.  Running this trampoline therefore
 * discards everything pushed after those two words, restores EBP,
 * and returns to the saved EIP.  The symbol carries a leading
 * underscore because open_stack_allocation refers to it from C as
 * stack_allocation_recover (the Darwin C symbol prefix).  It is
 * written as file-scope asm, presumably so that no compiler-generated
 * prologue or epilogue disturbs the hand-built frame. */
asm("_stack_allocation_recover: movl %ebp, %esp; popl %ebp; ret;");

/* Open a stack allocation on CONTEXT.  The current EIP and EBP are
 * saved on its stack and EBP is pointed at the saved EBP.  EIP is
 * redirected to the recovery trampoline, and ESP is padded down to a
 * 16-byte boundary.  When CONTEXT is later resumed at the trampoline,
 * it restores the original ESP and EBP and returns to the original
 * EIP. */
void open_stack_allocation(i386_thread_state_t *context)
{
    void stack_allocation_recover(void);

    push_context(context->eip, context);
    push_context(context->ebp, context);
    context->ebp = context->esp;
    /* EIP is an integer field; convert the code address explicitly. */
    context->eip = (unsigned int) stack_allocation_recover;

    /* Align after saving EIP/EBP: the trampoline reloads ESP from EBP,
     * so any padding below the saved words is discarded on recovery. */
    align_context_stack(context);
}

/* Reserve SIZE bytes on CONTEXT's stack by lowering ESP.  SIZE is
 * rounded up to a multiple of 16, so an ESP that starts 16-byte
 * aligned stays aligned.  Returns the start of the reserved block,
 * which is the new ESP. */
void *stack_allocate(i386_thread_state_t *context, size_t size)
{
    size_t rounded = (size + 15) & ~(size_t) 15;

    context->esp -= rounded;
    return (void *) context->esp;
}

/* Arrange for CONTEXT, when resumed, to call FUNCTION with the NARGS
 * variadic arguments that follow.  Each argument is read as one u32;
 * pointers and ints are that size on i386.
 *
 * This has to follow cdecl calling conventions (caller removes args,
 * first argument at the lowest stack address) and the x86/darwin rule
 * that ESP is 16-byte aligned at the call.  The simplest way to
 * arrange this is to open a new stack allocation.  When FUNCTION
 * returns, it lands in the recovery trampoline, which restores ESP and
 * EBP and resumes CONTEXT's previous EIP.
 *
 * Arguments are given in C parameter order: the first variadic
 * argument becomes FUNCTION's first parameter. */
void call_c_function_in_context(i386_thread_state_t *context,
                                void *function,
                                int nargs,
                                ...)
{
    va_list ap;
    u32 *args;
    int i;

    assert(nargs >= 0);

    /* Set up to restore stack on exit; this leaves ESP 16-byte aligned. */
    open_stack_allocation(context);

    /* Pad with (-nargs mod 4) zero words so that the padding plus the
     * arguments fill whole 16-byte units, keeping ESP aligned. */
    for (i = (3 & -nargs); i; i--) {
        push_context(0, context);
    }

    /* Reserve the argument block, then fill it upward from the new
     * ESP so the first argument sits at the lowest address. */
    context->esp -= nargs * sizeof(u32);
    args = (u32 *) context->esp;

    va_start(ap, nargs);
    for (i = 0; i < nargs; i++) {
        args[i] = va_arg(ap, u32);
    }
    va_end(ap);

    /* Return address: open_stack_allocation left EIP pointing at the
     * recovery trampoline, so returning from FUNCTION unwinds the
     * allocation. */
    push_context(context->eip, context);
    context->eip = (unsigned int) function;
}

/* Runs in the faulting thread (entered via call_c_function_in_context)
 * to emulate delivery of SIGNAL to HANDLER.
 *
 * THREAD_STATE points at the register state to resume once the handler
 * is done.  It is converted into a ucontext for the handler, and any
 * changes the handler makes to that context are copied back into
 * *THREAD_STATE.
 *
 * This is not expected to return.  It finishes with a marker trap that
 * the EXC_BAD_INSTRUCTION path in catch_exception_raise recognizes as
 * a request to reload *THREAD_STATE into this thread. */
void signal_emulation_wrapper(i386_thread_state_t *thread_state,
                              siginfo_t *siginfo,
                              int signal,
                              void (*handler)(int, siginfo_t *, void *))
{
    struct ucontext context;
    struct mcontext regs;

    context.uc_mcontext = &regs;
    build_fake_signal_context(&context, thread_state);

    handler(signal, siginfo, &context);

    update_thread_state_from_context(thread_state, &context);

    /* Trap to restore the signal context.  The word 0xffff0b0f is a ud2
     * (0f 0b) followed by ff ff.  The exception handler matches it and
     * treats EAX as a pointer to the i386 thread state to resume, so
     * EAX must hold THREAD_STATE (not the ucontext).  The "a"
     * constraint loads EAX and tells the compiler it is in use. */
    asm volatile (".long 0xffff0b0f" : : "a" (thread_state));
}

/* Debug aid: print every register in CONTEXT, then the 32 words at the
 * top of its stack.  The stack words are read directly from memory at
 * ESP. */
void dump_context(i386_thread_state_t *context)
{
    int i;
    u32 *stack_pointer;

    /* The thread-state fields are unsigned int (ESP is assigned from an
     * (unsigned int) cast in push_context), so they are printed with
     * %x.  Stack words are u32 (unsigned long) and keep %lx. */
    printf("eax: %08x  ecx: %08x  edx: %08x  ebx: %08x\n",
           context->eax, context->ecx, context->edx, context->ebx);
    printf("esp: %08x  ebp: %08x  esi: %08x  edi: %08x\n",
           context->esp, context->ebp, context->esi, context->edi);
    printf("eip: %08x  eflags: %08x\n",
           context->eip, context->eflags);
    printf("cs: %04hx  ds: %04hx  es: %04hx  "
           "ss: %04hx  fs: %04hx  gs: %04hx\n",
           context->cs, context->ds, context->es,
           context->ss, context->fs, context->gs);

    stack_pointer = (u32 *)context->esp;
    for (i = 0; i < 32; i++) {
        printf("%08x: %08lx\n", context->esp + (i * 4), stack_pointer[i]);
    }
}

/* Mach exception callback; mach_msg_server dispatches exceptions here.
 *
 * EXC_BAD_ACCESS: emulate delivery of SIGBUS in the faulting thread.
 * The faulting thread's stack gets a copy of its register state and a
 * fake siginfo.  The thread is then redirected into
 * signal_emulation_wrapper with bogus_handler as the handler.
 *
 * EXC_BAD_INSTRUCTION: if the faulting instruction is the 0xffff0b0f
 * marker, EAX points at an i386_thread_state_t to resume, and the
 * thread is reset to it (fake sigreturn).  Any other illegal
 * instruction is dumped and refused.
 *
 * Returns KERN_SUCCESS once the thread has been redirected, or the
 * failing kern_return_t if reading or writing thread state fails.
 * Returns KERN_INVALID_RIGHT for exceptions not handled here. */
kern_return_t
catch_exception_raise(mach_port_t exception_port,
                      mach_port_t thread,
                      mach_port_t task,
                      exception_type_t exception,
                      exception_data_t code_vector,
                      mach_msg_type_number_t code_count)
{
    /* All-zero template for the fake siginfo (static storage is
     * zero-initialized). */
    static const siginfo_t empty_siginfo;
    kern_return_t ret;
    i386_exception_state_t exception_state;
    mach_msg_type_number_t exception_state_count = i386_EXCEPTION_STATE_COUNT;
    i386_thread_state_t thread_state;
    i386_thread_state_t backup_thread_state;
    i386_thread_state_t *target_thread_state;
    mach_msg_type_number_t thread_state_count = i386_THREAD_STATE_COUNT;
    siginfo_t *siginfo;
    int signal;

    switch (exception) {

    case EXC_BAD_ACCESS:
        ret = thread_get_state(thread,
                               i386_THREAD_STATE,
                               (thread_state_t)&thread_state,
                               &thread_state_count);
        if (ret != KERN_SUCCESS)
            return ret;

        ret = thread_get_state(thread,
                               i386_EXCEPTION_STATE,
                               (thread_state_t)&exception_state,
                               &exception_state_count);
        if (ret != KERN_SUCCESS)
            return ret;

        backup_thread_state = thread_state;

        signal = SIGBUS;

        /* Align the faulting thread's stack and arrange for it to
         * unwind back to the faulting EIP. */
        open_stack_allocation(&thread_state);

        /* Save the original thread state; the emulated sigreturn
         * resumes from this copy. */
        target_thread_state =
            stack_allocate(&thread_state, sizeof(*target_thread_state));
        (*target_thread_state) = backup_thread_state;

        /* Set up siginfo */
        siginfo = stack_allocate(&thread_state, sizeof(*siginfo));
        /* what do we need to put in our fake siginfo?  It looks like
         * the x86 code only uses si_signo and si_addr.  Zero the rest
         * so the handler never sees stale stack contents. */
        *siginfo = empty_siginfo;
        siginfo->si_signo = signal;
        siginfo->si_addr = (void*)exception_state.faultvaddr;

        /* Set up to call the signal handler emulator.  Arguments are
         * listed in signal_emulation_wrapper's parameter order:
         * (thread_state, siginfo, signal, handler). */
        call_c_function_in_context(&thread_state,
                                   signal_emulation_wrapper,
                                   4,
                                   target_thread_state,
                                   siginfo,
                                   signal,
                                   bogus_handler);

        return thread_set_state(thread,
                                i386_THREAD_STATE,
                                (thread_state_t)&thread_state,
                                thread_state_count);

    case EXC_BAD_INSTRUCTION:
        ret = thread_get_state(thread,
                               i386_THREAD_STATE,
                               (thread_state_t)&thread_state,
                               &thread_state_count);
        if (ret != KERN_SUCCESS)
            return ret;

        if (0xffff0b0f == *((u32 *)thread_state.eip)) {
            /* fake sigreturn. */

            /* When we get here, thread_state.eax is a pointer to a
             * thread_state to restore. */
            thread_state = *((i386_thread_state_t *)thread_state.eax);

            return thread_set_state(thread,
                                    i386_THREAD_STATE,
                                    (thread_state_t)&thread_state,
                                    thread_state_count);
        }

        dump_context(&thread_state);

        /* FALL THROUGH */

    default:
        return KERN_INVALID_RIGHT;
    }
}

extern boolean_t exc_server();

void *
mach_exception_handler(void *port)
{
    mach_msg_server(exc_server, 2048, (mach_port_t) port, 0);
    /* mach_msg_server should never return, but it should dispatch mach
     * exceptions to our catch_exception_raise function
     */
    abort();
}

/* If KR is not KERN_SUCCESS, report which call (WHAT) failed and
 * abort.  Without a working exception port the test cannot catch its
 * own fault. */
static void check_kern_return(kern_return_t kr, const char *what)
{
    if (kr != KERN_SUCCESS) {
        fprintf(stderr, "%s failed: %s\n", what, mach_error_string(kr));
        abort();
    }
}

/* Set up the exception port and start the thread that serves it:
 *  - allocate a port with a receive right and add a send right;
 *  - install it as the calling thread's port for EXC_BAD_ACCESS and
 *    EXC_BAD_INSTRUCTION (EXCEPTION_DEFAULT behavior, no thread state
 *    in the message);
 *  - start a pthread running mach_exception_handler on that port.
 * Only the calling thread's exceptions are redirected.  Aborts if any
 * step fails. */
void start_exception_thread(void) {
    thread_port_t main_thread;
    static mach_port_t exception_port = MACH_PORT_NULL;
    kern_return_t ret;
    pthread_attr_t attr;
    pthread_t returned_thread = (pthread_t) 0;
    int err;

    ret = mach_port_allocate(mach_task_self(),
                             MACH_PORT_RIGHT_RECEIVE,
                             &exception_port);
    check_kern_return(ret, "mach_port_allocate");

    ret = mach_port_insert_right(mach_task_self(),
                                 exception_port,
                                 exception_port,
                                 MACH_MSG_TYPE_MAKE_SEND);
    check_kern_return(ret, "mach_port_insert_right");

    main_thread = mach_thread_self();
    ret = thread_set_exception_ports(main_thread,
                                     EXC_MASK_BAD_ACCESS | EXC_MASK_BAD_INSTRUCTION,
                                     exception_port,
                                     EXCEPTION_DEFAULT,
                                     THREAD_STATE_NONE);
    /* mach_thread_self() returns a new send right; release it. */
    mach_port_deallocate(mach_task_self(), main_thread);
    check_kern_return(ret, "thread_set_exception_ports");

    fprintf(stderr, "Creating mach_exception_handler thread!\n");
    pthread_attr_init(&attr);
    err = pthread_create(&returned_thread, &attr,
                         mach_exception_handler, (void*) exception_port);
    pthread_attr_destroy(&attr);
    if (err != 0) {
        fprintf(stderr, "pthread_create failed: %d\n", err);
        abort();
    }
}

/* Install the exception handler, then fault on purpose.  The handler
 * thread converts the resulting EXC_BAD_ACCESS into an emulated SIGBUS
 * delivered to bogus_handler, which exits the process on its third
 * call. */
int main(void)
{
    start_exception_thread();

    /* Write to address 0 to raise EXC_BAD_ACCESS.  The volatile
     * qualifier stops the compiler from treating this undefined
     * behaviour as unreachable and deleting or replacing the store. */
    *((volatile char *) 0) = 0;

    return 0;
}

/* EOF */
