#include "../../include/ssa/ssa.h"
#include "../../include/std/mem.h"
#include "../../include/std/log.h"
#include "../../include/std/hash.h"
#include "../../include/ast.h"
#include <assert.h>
#include <setjmp.h>

/*
 * Non-local error exit used by the *_generate_ssa functions below: longjmp back to
 * whatever setjmp filled `context->env`. Requires a `context` variable in scope.
 */
#define throw(result) do { longjmp(context->env, (result)); } while(0)
#define throw_if_error(result) \
    do { \
        int return_if_result; \
        return_if_result = (result); \
        if((return_if_result) != 0) \
            throw(return_if_result); \
    } while(0)

/*
 * True when incrementing `counter` would produce a value that no longer fits in
 * `index_type`. The sum is converted back to `index_type` before comparing: when the
 * counter is narrower than int (e.g. a u16), `counter + 1` is computed in int because
 * of integer promotion and never wraps, so a plain `counter + 1 < counter` check can
 * never fire. Arguments must be side-effect free (`counter` is evaluated twice).
 */
#define ssa_index_would_overflow(index_type, counter) \
    ((index_type)((counter) + 1) < (index_type)(counter))

/*
 * Key comparison for the intermediates hash map.
 * Returns 0 when both numbers have the same type and the same payload, 1 otherwise.
 * The payload is compared through the `integer` union member, so floats compare by
 * bit pattern.
 */
static int compare_number(const void *a, const void *b) {
    const SsaNumber *lhs;
    const SsaNumber *rhs;
    lhs = a;
    rhs = b;
    if(rhs->type == lhs->type && rhs->value.integer == lhs->value.integer)
        return 0;
    return 1;
}

/*
 * Hash for the intermediates hash map: the `integer` view of the number's value.
 * `data` is copied out rather than cast, so no alignment of `data` is assumed.
 */
static usize hash_number(const u8 *data, usize size) {
    SsaNumber number;
    assert(size == sizeof(SsaNumber));
    am_memcpy(&number, data, size);
    return number.value.integer;
}

/* Builds an integer SsaNumber. */
SsaNumber create_ssa_integer(i64 value) {
    SsaNumber result;
    result.value.integer = value;
    result.type = SSA_NUMBER_TYPE_INTEGER;
    return result;
}

/* Builds a floating-point SsaNumber. */
SsaNumber create_ssa_float(f64 value) {
    SsaNumber result;
    result.value.floating = value;
    result.type = SSA_NUMBER_TYPE_FLOAT;
    return result;
}

/*
 * Copies `size` bytes from `value` into the instruction buffer at byte `offset`.
 * Instructions are packed byte-by-byte (a u8 opcode followed directly by operands),
 * so operands can land at odd offsets; copying bytes avoids the misaligned,
 * type-punned stores that `*(T*)&data[offset] = v` would perform.
 * The caller must already have grown the buffer to cover [offset, offset + size).
 */
static void ssa_write_at(Ssa *self, usize offset, const void *value, usize size) {
    am_memcpy(&self->instructions.data[offset], value, size);
}

/*
 * Initializes an empty SSA program: the instruction buffer, the number/string
 * de-duplication maps, and all index counters (reset to 0).
 * Returns 0 on success, or the first non-zero error from the buffer/map init calls.
 */
int ssa_init(Ssa *self, ScopedAllocator *allocator) {
    return_if_error(buffer_init(&self->instructions, allocator));
    return_if_error(hash_map_init(&self->intermediates, allocator, sizeof(SsaIntermediateIndex), compare_number, hash_number));
    return_if_error(hash_map_init(&self->strings, allocator, sizeof(SsaStringIndex), hash_compare_string, amal_hash_string));
    self->intermediate_counter = 0;
    self->string_counter = 0;
    self->reg_counter = 0;
    self->func_counter = 0;
    return 0;
}

/*
 * Allocates a fresh register number into *result.
 * Returns 0 on success, -1 if the register counter would overflow SsaRegister.
 */
int ssa_get_unique_reg(Ssa *self, SsaRegister *result) {
    /* Overflow */
    if(ssa_index_would_overflow(SsaRegister, self->reg_counter))
        return -1;
    *result = self->reg_counter++;
    return 0;
}

/*
 * Looks up `number` in the intermediates map; if absent, assigns it the next
 * intermediate index and inserts it. The (existing or new) index is stored in
 * *result_index. Returns 0 on success, -1 on index overflow, or the insert error.
 */
static CHECK_RESULT int ssa_try_add_intermediate(Ssa *self, SsaNumber number, SsaIntermediateIndex *result_index) {
    bool exists;
    BufferView key;
    assert(result_index);
    /* NOTE(review): `key` points at the local `number`; this assumes hash_map_insert
       copies the key bytes rather than storing the pointer - confirm in hash_map. */
    key = create_buffer_view((const char*)&number, sizeof(number));
    exists = hash_map_get(&self->intermediates, key, result_index);
    if(exists)
        return 0;

    /* Overflow */
    if(ssa_index_would_overflow(SsaIntermediateIndex, self->intermediate_counter))
        return -1;

    *result_index = self->intermediate_counter;
    ++self->intermediate_counter;
    switch(number.type) {
        case SSA_NUMBER_TYPE_INTEGER: {
            /* Cast: i64 is not necessarily `long long`, which %lld requires. */
            amal_log_debug("i%u = %lld", *result_index, (long long)number.value.integer);
            break;
        }
        case SSA_NUMBER_TYPE_FLOAT: {
            amal_log_debug("i%u = %f", *result_index, number.value.floating);
            break;
        }
    }
    return hash_map_insert(&self->intermediates, key, result_index);
}

/*
 * Looks up `str` in the strings map; if absent, assigns it the next string index and
 * inserts it. The (existing or new) index is stored in *result_index.
 * Returns 0 on success, -1 on index overflow, or the insert error.
 */
static CHECK_RESULT int ssa_try_add_string(Ssa *self, BufferView str, SsaStringIndex *result_index) {
    bool exists;
    assert(result_index);

    exists = hash_map_get(&self->strings, str, result_index);
    if(exists)
        return 0;

    /* Overflow */
    if(ssa_index_would_overflow(SsaStringIndex, self->string_counter))
        return -1;

    *result_index = self->string_counter;
    ++self->string_counter;
    /* The `*` precision of %.*s must be passed as an int, not a usize. */
    amal_log_debug("s%u = \"%.*s\"", *result_index, (int)str.size, str.data);
    return hash_map_insert(&self->strings, str, result_index);
}

/*
 * Appends a "form 1" instruction: [u8 ins_type][SsaRegister lhs][u16 rhs].
 * Returns 0 on success or the buffer_append error.
 */
static CHECK_RESULT int ssa_add_ins_form1(Ssa *self, SsaInstructionType ins_type, SsaRegister lhs, u16 rhs) {
    usize index;
    index = self->instructions.size;

    return_if_error(buffer_append(&self->instructions, NULL, sizeof(u8) + sizeof(SsaRegister) + sizeof(u16)));
    self->instructions.data[index + 0] = ins_type;
    ssa_write_at(self, index + 1, &lhs, sizeof(lhs));
    ssa_write_at(self, index + 1 + sizeof(SsaRegister), &rhs, sizeof(rhs));
    return 0;
}

/* Operator symbol for a binop instruction type (debug logging only). */
static const char* binop_type_to_string(SsaInstructionType binop_type) {
    assert(binop_type >= SSA_ADD && binop_type <= SSA_DIV);
    switch(binop_type) {
        case SSA_ADD: return "+";
        case SSA_SUB: return "-";
        case SSA_MUL: return "*";
        case SSA_DIV: return "/";
        default: return "";
    }
}

/*
 * Appends a "form 2" instruction: [u8 ins_type][result][lhs][rhs], all registers,
 * where `result` is a freshly allocated register also returned through *result.
 * Returns 0 on success, -1 on register overflow, or the buffer_append error.
 */
static CHECK_RESULT int ssa_add_ins_form2(Ssa *self, SsaInstructionType ins_type, SsaRegister lhs, SsaRegister rhs, SsaRegister *result) {
    usize index;
    index = self->instructions.size;

    /* Overflow */
    if(ssa_index_would_overflow(SsaRegister, self->reg_counter))
        return -1;

    assert(result);
    return_if_error(buffer_append(&self->instructions, NULL, sizeof(u8) + sizeof(SsaRegister) + sizeof(SsaRegister) + sizeof(SsaRegister)));
    *result = self->reg_counter++;
    self->instructions.data[index + 0] = ins_type;
    ssa_write_at(self, index + 1, result, sizeof(*result));
    ssa_write_at(self, index + 1 + sizeof(SsaRegister), &lhs, sizeof(lhs));
    ssa_write_at(self, index + 1 + 2 * sizeof(SsaRegister), &rhs, sizeof(rhs));
    amal_log_debug("r%u = r%u %s r%u", *result, lhs, binop_type_to_string(ins_type), rhs);
    return 0;
}

/* Emits `dest = <number>` via the de-duplicated intermediates table. */
int ssa_ins_assign_inter(Ssa *self, SsaRegister dest, SsaNumber number) {
    SsaIntermediateIndex index;
    return_if_error(ssa_try_add_intermediate(self, number, &index));
    amal_log_debug("r%u = i%u", dest, index);
    return ssa_add_ins_form1(self, SSA_ASSIGN_INTER, dest, index);
}

/* Emits `dest = <string>` via the de-duplicated strings table. */
int ssa_ins_assign_string(Ssa *self, SsaRegister dest, BufferView str) {
    SsaStringIndex index;
    return_if_error(ssa_try_add_string(self, str, &index));
    amal_log_debug("r%u = s%u", dest, index);
    return ssa_add_ins_form1(self, SSA_ASSIGN_STRING, dest, index);
}

/* Emits `dest = src` (register to register). */
int ssa_ins_assign_reg(Ssa *self, SsaRegister dest, SsaRegister src) {
    amal_log_debug("r%u = r%u", dest, src);
    /* NOTE(review): this encodes SSA_ASSIGN_INTER, so a consumer will read `src` as an
       intermediate index rather than a register. A register-assign opcode was likely
       intended - confirm against the SsaInstructionType enum. */
    return ssa_add_ins_form1(self, SSA_ASSIGN_INTER, dest, src);
}

/* Emits `*result = lhs <op> rhs` for one of SSA_ADD..SSA_DIV. */
int ssa_ins_binop(Ssa *self, SsaInstructionType binop_type, SsaRegister lhs, SsaRegister rhs, SsaRegister *result) {
    assert(binop_type >= SSA_ADD && binop_type <= SSA_DIV);
    return ssa_add_ins_form2(self, binop_type, lhs, rhs, result);
}

/*
 * Emits [u8 SSA_FUNC_START][SsaFuncIndex][u8 num_args] and returns the newly
 * allocated function index through *result.
 * Returns 0 on success, -1 on function index overflow, or the buffer_append error.
 */
int ssa_ins_func_start(Ssa *self, u8 num_args, SsaFuncIndex *result) {
    usize index;
    index = self->instructions.size;

    /* Overflow */
    if(ssa_index_would_overflow(SsaFuncIndex, self->func_counter))
        return -1;

    return_if_error(buffer_append(&self->instructions, NULL, sizeof(u8) + sizeof(SsaFuncIndex) + sizeof(u8)));
    *result = self->func_counter++;
    self->instructions.data[index + 0] = SSA_FUNC_START;
    ssa_write_at(self, index + 1, result, sizeof(*result));
    self->instructions.data[index + 1 + sizeof(SsaFuncIndex)] = num_args;
    amal_log_debug("FUNC_START f%u(%u)", *result, num_args);
    return 0;
}

/* Emits the single-byte SSA_FUNC_END instruction. */
int ssa_ins_func_end(Ssa *self) {
    u8 ins;
    ins = SSA_FUNC_END;
    amal_log_debug("FUNC_END");
    return buffer_append(&self->instructions, &ins, 1);
}

/* Emits [u8 SSA_PUSH][SsaRegister reg]. */
int ssa_ins_push(Ssa *self, SsaRegister reg) {
    usize index;
    index = self->instructions.size;

    return_if_error(buffer_append(&self->instructions, NULL, sizeof(u8) + sizeof(SsaRegister)));
    self->instructions.data[index + 0] = SSA_PUSH;
    ssa_write_at(self, index + 1, &reg, sizeof(reg));
    amal_log_debug("PUSH r%u", reg);
    return 0;
}

/*
 * Emits [u8 SSA_CALL][SsaFuncIndex func][SsaRegister result], where `result` is a
 * freshly allocated register also returned through *result.
 * Returns 0 on success, -1 on register overflow, or the buffer_append error.
 */
int ssa_ins_call(Ssa *self, SsaFuncIndex func, SsaRegister *result) {
    usize index;
    index = self->instructions.size;

    /* Overflow */
    if(ssa_index_would_overflow(SsaRegister, self->reg_counter))
        return -1;

    return_if_error(buffer_append(&self->instructions, NULL, sizeof(u8) + sizeof(SsaFuncIndex) + sizeof(SsaRegister)));
    *result = self->reg_counter++;
    self->instructions.data[index + 0] = SSA_CALL;
    ssa_write_at(self, index + 1, &func, sizeof(func));
    ssa_write_at(self, index + 1 + sizeof(SsaFuncIndex), result, sizeof(*result));
    amal_log_debug("r%u = CALL f%u", *result, func);
    return 0;
}

static CHECK_RESULT SsaRegister ast_generate_ssa(Ast *self, SsaCompilerContext *context);

/* Loads a numeric literal into a fresh register and returns that register. */
static CHECK_RESULT SsaRegister number_generate_ssa(Number *self, SsaCompilerContext *context) {
    SsaRegister reg;
    SsaNumber number;
    if(self->is_integer)
        number = create_ssa_integer(self->value.integer);
    else
        number = create_ssa_float(self->value.floating);
    throw_if_error(ssa_get_unique_reg(&context->ssa, &reg));
    throw_if_error(ssa_ins_assign_inter(&context->ssa, reg, number));
    return reg;
}

/* Emits FUNC_START, the function body, then FUNC_END. Produces no value (returns 0). */
static CHECK_RESULT SsaRegister funcdecl_generate_ssa(FunctionDecl *self, SsaCompilerContext *context) {
    /* TODO: Implement */
    throw_if_error(ssa_ins_func_start(&context->ssa, 0, &self->ssa_func_index));
    scope_generate_ssa(&self->body, context);
    throw_if_error(ssa_ins_func_end(&context->ssa));
    return 0;
}

/* Evaluates and pushes each argument in order, then emits a CALL; returns the result register. */
static CHECK_RESULT SsaRegister funccall_generate_ssa(FunctionCall *self, SsaCompilerContext *context) {
    /* TODO: Implement */
    Ast *ast;
    Ast *ast_end;
    SsaRegister reg;

    ast = buffer_start(&self->args);
    ast_end = buffer_end(&self->args);
    for(; ast != ast_end; ++ast) {
        SsaRegister arg_reg;
        arg_reg = ast_generate_ssa(ast, context);
        throw_if_error(ssa_ins_push(&context->ssa, arg_reg));
    }

    /* TODO: Use real func index */
    throw_if_error(ssa_ins_call(&context->ssa, 0, &reg));
    return reg;
}

/* Lowers the struct body's scope. Produces no value (returns 0). */
static CHECK_RESULT SsaRegister structdecl_generate_ssa(StructDecl *self, SsaCompilerContext *context) {
    /* TODO: Implement */
    scope_generate_ssa(&self->body, context);
    return 0;
}

/* Not lowered yet; returns 0. */
static CHECK_RESULT SsaRegister structfield_generate_ssa(StructField *self, SsaCompilerContext *context) {
    /* TODO: Implement */
    (void)self;
    (void)context;
    return 0;
}

/* Lowers the right-hand side expression and returns its register. */
static CHECK_RESULT SsaRegister lhs_generate_ssa(LhsExpr *self, SsaCompilerContext *context) {
    /* TODO: Implement */
    SsaRegister rhs_reg;
    rhs_reg = ast_generate_ssa(&self->rhs_expr, context);
    /* TODO: Is this correct? */
    return rhs_reg;
}

/* Loads a string literal into a fresh register and returns that register. */
static CHECK_RESULT SsaRegister string_generate_ssa(String *self, SsaCompilerContext *context) {
    SsaRegister reg;
    throw_if_error(ssa_get_unique_reg(&context->ssa, &reg));
    throw_if_error(ssa_ins_assign_string(&context->ssa, reg, self->str));
    return reg;
}

/* Not lowered yet; returns 0. */
static CHECK_RESULT SsaRegister variable_generate_ssa(Variable *self, SsaCompilerContext *context) {
    /* TODO: Implement */
    (void)self;
    (void)context;
    return 0;
}

/* Maps an AST binary operator to its SSA instruction type (BINOP_DOT is not supported yet). */
static SsaInstructionType binop_type_to_ssa_type(BinopType binop_type) {
    switch(binop_type) {
        case BINOP_ADD:
            return SSA_ADD;
        case BINOP_SUB:
            return SSA_SUB;
        case BINOP_MUL:
            return SSA_MUL;
        case BINOP_DIV:
            return SSA_DIV;
        case BINOP_DOT:
            assert(bool_false && "TODO: Implement dot access");
            return 0;
    }
    return 0;
}

/* Lowers lhs then rhs, emits the binop, and returns the result register. */
static CHECK_RESULT SsaRegister binop_generate_ssa(Binop *self, SsaCompilerContext *context) {
    SsaRegister lhs_reg;
    SsaRegister rhs_reg;
    SsaRegister reg;
    lhs_reg = ast_generate_ssa(&self->lhs, context);
    rhs_reg = ast_generate_ssa(&self->rhs, context);
    throw_if_error(ssa_ins_binop(&context->ssa, binop_type_to_ssa_type(self->type), lhs_reg, rhs_reg, &reg));
    return reg;
}

/*
 * Dispatches on the node type of an already-resolved AST node and returns the
 * register holding its value; node kinds that produce no value return 0.
 * Errors are reported by longjmp through context->env (see throw).
 */
CHECK_RESULT SsaRegister ast_generate_ssa(Ast *self, SsaCompilerContext *context) {
    assert(self->resolve_status == AST_RESOLVED);
    switch(self->type) {
        case AST_NONE:
            return 0;
        case AST_NUMBER:
            return number_generate_ssa(self->value.number, context);
        case AST_FUNCTION_DECL:
            return funcdecl_generate_ssa(self->value.func_decl, context);
        case AST_FUNCTION_CALL:
            return funccall_generate_ssa(self->value.func_call, context);
        case AST_STRUCT_DECL:
            return structdecl_generate_ssa(self->value.struct_decl, context);
        case AST_STRUCT_FIELD:
            return structfield_generate_ssa(self->value.struct_field, context);
        case AST_LHS:
            return lhs_generate_ssa(self->value.lhs_expr, context);
        case AST_IMPORT:
            /* TODO: When @import(...).data syntax is added, implement the generate ssa for it */
            return 0;
        case AST_STRING:
            return string_generate_ssa(self->value.string, context);
        case AST_VARIABLE:
            return variable_generate_ssa(self->value.variable, context);
        case AST_BINOP:
            return binop_generate_ssa(self->value.binop, context);
    }
    return 0;
}

/* Lowers every AST object of the scope in order, discarding their result registers. */
void scope_generate_ssa(Scope *self, SsaCompilerContext *context) {
    Ast *ast;
    Ast *ast_end;
    ast = buffer_start(&self->ast_objects);
    ast_end = buffer_end(&self->ast_objects);
    for(; ast != ast_end; ++ast) {
        ignore_result_int(ast_generate_ssa(ast, context));
    }
}