diff --git a/compiler/compile.c b/compiler/compile.c index bad22b8..0dfd466 100644 --- a/compiler/compile.c +++ b/compiler/compile.c @@ -63,6 +63,29 @@ void emit_function(struct context *ctx, struct function *fn) { } // Obligatory program return emit_insn(fn, OP(OP_RET, 0)); + + if (ctx->parent) { + fn->local_count = ctx->var_counter; + } else { + // Root function + fn->local_count = 0; + } +} + +static void emit_store(struct function *fn, int kind, size_t index) { + switch (kind) { + case NAME_LOCAL: + emit_insn(fn, OP(OP_STL, index)); + break; + case NAME_GLOBAL: + emit_insn(fn, OP(OP_STG, index)); + break; + case NAME_ARGUMENT: + emit_insn(fn, OP(OP_STARG, index)); + break; + default: + abort(); + } } void emit(struct function *fn, struct context *ctx, struct node *expr) { @@ -93,21 +116,17 @@ void emit(struct function *fn, struct context *ctx, struct node *expr) { int kind; size_t index; - assert(ctx_lookup_name(ctx, n0->n_ident, &kind, &index) == 0); + if (ctx_lookup_name(ctx, n0->n_ident, &kind, &index) != 0) { + fprintf(stderr, "Unresolved reference to %s\n", n0->n_ident); + abort(); + } n1 = caddr(expr); assert(null_q(cdr(cddr(expr)))); // (define ident value) // Emit value for n1 emit(fn, ctx, n1); - switch (kind) { - case NAME_LOCAL: - emit_insn(fn, OP(OP_STL, index)); - break; - case NAME_GLOBAL: - emit_insn(fn, OP(OP_STG, index)); - break; - } + emit_store(fn, kind, index); } else if (cons_q(n0)) { struct node *name = car(n0); struct node *args = cdr(n0); @@ -125,7 +144,10 @@ void emit(struct function *fn, struct context *ctx, struct node *expr) { struct context new_ctx; new_fn->args = args; new_fn->body = body; + new_fn->local_count = 0; + assert(!ctx->parent); ctx_init(&new_ctx, ctx); + new_ctx.var_counter = 0; new_ctx.owner = new_fn; emit_function(&new_ctx, new_fn); @@ -142,10 +164,9 @@ void emit(struct function *fn, struct context *ctx, struct node *expr) { size_t index; assert(ctx_lookup_name(ctx, n0->n_ident, &kind, &index) == 0); - assert(kind == NAME_GLOBAL); - emit(fn, ctx, n1); - emit_insn(fn, OP(OP_STG, index)); + + emit_store(fn, kind, index); return; } else if (!strcmp(n0->n_ident, "use")) { // Ignore @@ -279,6 +300,10 @@ void emit(struct function *fn, struct context *ctx, struct node *expr) { case NAME_LOCAL: emit_insn(fn, OP(OP_LCALL, index)); break; + case NAME_ARGUMENT: + emit_insn(fn, OP(OP_LDARG, index)); + emit_insn(fn, OP(OP_CALL, 0)); + break; case NAME_GLOBAL: emit_insn(fn, OP(OP_GCALL, index)); break; diff --git a/compiler/include/unit.h b/compiler/include/unit.h index b3ae5ee..69d688e 100644 --- a/compiler/include/unit.h +++ b/compiler/include/unit.h @@ -35,6 +35,7 @@ struct ext_unit { struct function { size_t index; + size_t local_count; struct node *args, *body; struct vector bytecode; }; diff --git a/compiler/main.c b/compiler/main.c index 4325e55..6454835 100644 --- a/compiler/main.c +++ b/compiler/main.c @@ -117,6 +117,7 @@ static void write_unit(FILE *fp, struct unit *u) { ++argc; } bin_func.argc = argc; + bin_func.local_count = fn->local_count; bin_func.len = fn->bytecode.size * sizeof(uint32_t); fwrite(&bin_func, 1, sizeof(struct bin_func_entry), fp); fwrite(fn->bytecode.data, 1, fn->bytecode.size * sizeof(uint32_t), fp); diff --git a/compiler/unit.c b/compiler/unit.c index a00a2fa..0d76836 100644 --- a/compiler/unit.c +++ b/compiler/unit.c @@ -65,7 +65,6 @@ int ctx_lookup_name(struct context *ctx, const char *name, int *kind, size_t *in p = hash_lookup(&ctx->vars, name); if (p) { if (ctx->parent) { - // TODO only look into level-1 local contexts *kind = NAME_LOCAL; } else { *kind = NAME_GLOBAL; diff --git a/core/include/binary.h b/core/include/binary.h index faf930f..1c3040b 100644 --- a/core/include/binary.h +++ b/core/include/binary.h @@ -29,6 +29,7 @@ struct bin_ref_entry { struct bin_func_entry { uint32_t argc; + uint32_t local_count; uint32_t len; uint32_t data[0]; }; diff --git a/core/include/op.h b/core/include/op.h index 826064b..be596d1 100644 --- a/core/include/op.h +++ b/core/include/op.h @@ -28,6 +28,7 @@ #define OP_ISZ 0x46 #define OP_LDARG 0x4B +#define OP_STARG 0x4C #define OP_LDG 0x4E #define OP_STG 0x4F #define OP_LDF 0x50 @@ -37,7 +38,8 @@ #define OP_XCALL 0x60 #define OP_LCALL 0x61 #define OP_GCALL 0x62 -#define OP_JMP 0x63 -#define OP_BT 0x64 -#define OP_BF 0x65 +#define OP_CALL 0x63 +#define OP_JMP 0x64 +#define OP_BT 0x65 +#define OP_BF 0x66 #define OP_RET 0x6F diff --git a/vm/include/vmstate.h b/vm/include/vmstate.h index 2f098e5..731f2fe 100644 --- a/vm/include/vmstate.h +++ b/vm/include/vmstate.h @@ -5,6 +5,7 @@ #include "vector.h" #define MAXARG 12 +#define MAXLOC 64 struct vm_value; @@ -17,7 +18,9 @@ struct vm_ref_entry { }; struct vm_func_entry { - size_t argc; + size_t argc, local_count; + uint64_t local_regs[MAXLOC]; + uint64_t arg_regs[MAXARG]; uint32_t *bytecode; }; @@ -28,8 +31,6 @@ struct vm_state { uint64_t *call_stack; size_t csp, call_stack_size; - uint64_t arg_regs[MAXARG]; - struct vector ref_table; struct vector functions; diff --git a/vm/main.c b/vm/main.c index 586fe5e..c2bf844 100644 --- a/vm/main.c +++ b/vm/main.c @@ -103,14 +103,19 @@ found: for (size_t i = 0; i < hdr.func_table_size; ++i) { struct bin_func_entry ent; uint32_t *bytecode; - fread(&ent, 1, sizeof(struct bin_ref_entry), fp); + fread(&ent, 1, sizeof(struct bin_func_entry), fp); bytecode = malloc(ent.len); assert(bytecode); fread(bytecode, 1, ent.len, fp); struct vm_func_entry *func = vm_add_function(&vm); + if (i == 0) { + assert(!func->local_count); + } + assert(ent.local_count <= MAXLOC); func->argc = ent.argc; func->bytecode = bytecode; + func->local_count = ent.local_count; } fclose(fp); diff --git a/vm/vmstate.c b/vm/vmstate.c index 23490ff..1a550b2 100644 --- a/vm/vmstate.c +++ b/vm/vmstate.c @@ -97,8 +97,9 @@ struct vm_func_entry *vm_add_function(struct vm_state *vm) { void vm_call_index(struct vm_state *vm, size_t index) { struct vm_func_entry *func; assert(index < vm->functions.size); - assert(vm->csp); + assert(vm->csp > 2); + vm->call_stack[--vm->csp] = vm->fp; vm->call_stack[--vm->csp] = vm->ip; vm->fp = index; vm->ip = 0; @@ -106,7 +107,7 @@ void vm_call_index(struct vm_state *vm, size_t index) { func = vector_ref(&vm->functions, index); assert(func->argc <= MAXARG); for (size_t i = 0; i < func->argc; ++i) { - vm->arg_regs[i] = pop(vm); + func->arg_regs[i] = pop(vm); } } @@ -116,13 +117,19 @@ void vm_call_ref(struct vm_state *vm, struct vm_value *ref) { vm_call_index(vm, ref->v_func.fn_index); } -int vm_eval_opcode(struct vm_state *vm, uint32_t opcode) { +int vm_eval_opcode(struct vm_state *vm, struct vm_func_entry *func, uint32_t opcode) { uint64_t w0, w1; + int64_t sw0, sw1; size_t i0; ssize_t ii0; struct vm_ref_entry *r0; switch (opcode >> 24) { + case OP_ADD: + sw0 = pop_integer(vm); + sw1 = pop_integer(vm); + push_integer(vm, sw0 + sw1); + return 0; case OP_NOT: w0 = pop(vm); if (null_q(w0)) { @@ -169,7 +176,13 @@ int vm_eval_opcode(struct vm_state *vm, uint32_t opcode) { case OP_LDARG: i0 = opcode & 0xFFFFFF; assert(i0 < MAXARG); - push(vm, vm->arg_regs[i0]); + push(vm, func->arg_regs[i0]); + return 0; + case OP_STARG: + i0 = opcode & 0xFFFFFF; + w0 = pop(vm); + assert(i0 < MAXARG); + func->arg_regs[i0] = w0; return 0; case OP_LDG: i0 = opcode & 0xFFFFFF; @@ -188,9 +201,15 @@ int vm_eval_opcode(struct vm_state *vm, uint32_t opcode) { push_ref(vm, vm_func(0, i0)); return 0; case OP_STL: + i0 = opcode & 0xFFFFFF; + assert(i0 < func->local_count); + func->local_regs[i0] = pop(vm); + return 0; case OP_LDL: - printf("TODO implement local values\n"); - abort(); + i0 = opcode & 0xFFFFFF; + assert(i0 < func->local_count); + push(vm, func->local_regs[i0]); + return 0; // case OP_ISZ: w0 = pop(vm); @@ -217,6 +236,11 @@ int vm_eval_opcode(struct vm_state *vm, uint32_t opcode) { assert(func_q(w0)); vm_call_ref(vm, getref(w0)); return 0; + case OP_CALL: + w0 = pop(vm); + assert(func_q(w0)); + vm_call_ref(vm, getref(w0)); + return 0; case OP_BF: w0 = pop(vm); ii0 = sximm(opcode & 0xFFFFFF); @@ -231,7 +255,13 @@ int vm_eval_opcode(struct vm_state *vm, uint32_t opcode) { vm->ip += ii0 - 1; return 0; case OP_RET: - return -1; + if (vm->csp == vm->call_stack_size) { + // Return from main + return -1; + } + vm->ip = vm->call_stack[vm->csp++]; + vm->fp = vm->call_stack[vm->csp++]; + return 0; default: fprintf(stderr, "Undefined opcode: 0x%02hhx\n", opcode >> 24); abort(); @@ -242,5 +272,5 @@ int vm_eval_step(struct vm_state *vm) { assert(vm->fp < vm->functions.size); struct vm_func_entry *func = vector_ref(&vm->functions, vm->fp); uint32_t opcode = func->bytecode[vm->ip++]; - return vm_eval_opcode(vm, opcode); + return vm_eval_opcode(vm, func, opcode); }