diff --git a/main.c b/main.c index 85f23d5..4ca9bad 100644 --- a/main.c +++ b/main.c @@ -54,11 +54,26 @@ enum { #define IS_FLOAT(c) ((c >= '0' && c <= '9') || c == '.') #define IS_ALPHA(c) ((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')) +typedef struct Function { + const char *name; + real (*func)(real *args); + size_t n_args; +} Function; + +#define FUNCTIONS_CAP 256 +Function functions[FUNCTIONS_CAP]; +size_t functions_size = 0; + void push_tok(Tok t) { if (toks_size+1 < TOKS_CAP) toks[toks_size++] = t; } +void add_func(const char *name, real (*func)(real *args), size_t n_args) { + if (functions_size+1 < FUNCTIONS_CAP) + functions[functions_size++] = (Function){.name = name, .func = func, .n_args = n_args}; +} + void tokenize(char *expr) { push_tok((Tok){.kind = TokOp, .Char = '('}); @@ -222,22 +237,23 @@ real eval(Tok *t) { t -= 2; real outer_res; - if (strcmp(t[1].Str, "sqrt") == 0) { - if (arg_results_size != 1) { - fprintf(stderr, "Error: function sqrt() requires exactly 1 argument\n"); - exit(1); + bool func_found = false; + for (size_t i = 0; i < functions_size; i++) { + if (strcmp(t[1].Str, functions[i].name) == 0) { + func_found = true; + if (arg_results_size != functions[i].n_args) { + const char *plural = functions[i].n_args == 1 ? "" : "s"; + fprintf(stderr, "Error: function %s() requires exactly 1 argument%s\n", functions[i].name, plural); + exit(1); + } + outer_res = functions[i].func(arg_results); } - outer_res = sqrt(arg_results[0]); - } else if (strcmp(t[1].Str, "pow") == 0) { - if (arg_results_size != 2) { - fprintf(stderr, "Error: function pow() requires exactly 2 arguments\n"); - exit(1); - } - outer_res = pow(arg_results[0], arg_results[1]); - } else { - fprintf(stderr, "Error: unknown function name: %s\n", t[1].Str); + } + if (!func_found) { + fprintf(stderr, "Error: unknown function: %s()\n", t[1].Str); exit(1); } + t[1].kind = TokNum; t[1].Num = outer_res; } @@ -289,11 +305,24 @@ void cleanup() { } } +real fn_sqrt(real *args) { return sqrt(args[0]); } +real fn_pow(real *args) { return pow(args[0], args[1]); } +real fn_mod(real *args) { return fmod(args[0], args[1]); } +real fn_round(real *args) { return round(args[0]); } +real fn_floor(real *args) { return floor(args[0]); } +real fn_ceil(real *args) { return ceil(args[0]); } + int main(int argc, char **argv) { if (argc != 2 || strcmp(argv[1], "-h") == 0 || strcmp(argv[1], "--help") == 0) { fprintf(stderr, "Usage: ./exp \"\"\n"); exit(1); } + add_func("sqrt", fn_sqrt, 1); + add_func("pow", fn_pow, 2); + add_func("mod", fn_mod, 2); + add_func("round", fn_round, 1); + add_func("floor", fn_floor, 1); + add_func("ceil", fn_ceil, 1); tokenize(argv[1]); print_toks(); real res = eval(toks);