Extern functions

This commit is contained in:
nub31
2025-01-27 19:24:09 +01:00
parent cca7ff5b8e
commit ab3106434e
18 changed files with 239 additions and 138 deletions

4
.gitignore vendored
View File

@@ -33,4 +33,6 @@ Thumbs.db
Desktop.ini
.DS_Store
Nub.Lang/Nub.Lang/Output/out*
Nub.Lang/Nub.Lang/Output/*.o
Nub.Lang/Nub.Lang/Output/out
Nub.Lang/Nub.Lang/Output/out.asm

View File

@@ -13,7 +13,7 @@ public class BranchChecker
public void Check()
{
foreach (var funcDefinition in _definitions.OfType<FuncDefinitionNode>())
foreach (var funcDefinition in _definitions.OfType<LocalFuncDefinitionNode>())
{
if (funcDefinition.ReturnType.HasValue)
{

View File

@@ -11,6 +11,7 @@ public class Generator
private readonly SymbolTable _symbolTable;
private readonly StringBuilder _builder;
private readonly Dictionary<string, string> _strings;
private readonly HashSet<string> _externFuncDefinitions;
private int _stringIndex;
public Generator(IReadOnlyCollection<DefinitionNode> definitions)
@@ -18,8 +19,15 @@ public class Generator
_strings = [];
_definitions = definitions;
_builder = new StringBuilder();
_externFuncDefinitions = ["strcmp"];
_symbolTable = new SymbolTable(definitions.OfType<GlobalVariableDefinitionNode>().ToList());
foreach (var funcDefinitionNode in definitions.OfType<FuncDefinitionNode>())
foreach (var funcDefinitionNode in definitions.OfType<ExternFuncDefinitionNode>())
{
_symbolTable.DefineFunc(funcDefinitionNode);
_externFuncDefinitions.Add(_symbolTable.ResolveExternFunc(funcDefinitionNode.Name, funcDefinitionNode.Parameters.Select(p => p.Type).ToList()).StartLabel);
}
foreach (var funcDefinitionNode in definitions.OfType<LocalFuncDefinitionNode>())
{
_symbolTable.DefineFunc(funcDefinitionNode);
}
@@ -29,6 +37,11 @@ public class Generator
{
_builder.AppendLine("global _start");
foreach (var externFuncDefinition in _externFuncDefinitions)
{
_builder.AppendLine($"extern {externFuncDefinition}");
}
_builder.AppendLine();
_builder.AppendLine("section .bss");
foreach (var globalVariable in _definitions.OfType<GlobalVariableDefinitionNode>())
@@ -41,7 +54,7 @@ public class Generator
_builder.AppendLine("section .text");
_builder.AppendLine("_start:");
var main = _symbolTable.ResolveFunc(Entrypoint, []);
var main = _symbolTable.ResolveLocalFunc(Entrypoint, []);
_builder.AppendLine(" ; Initialize global variables");
foreach (var globalVariable in _definitions.OfType<GlobalVariableDefinitionNode>())
@@ -56,57 +69,18 @@ public class Generator
_builder.AppendLine($" call {main.StartLabel}");
_builder.AppendLine();
_builder.AppendLine(" ; Exit with status code 0");
_builder.AppendLine(main.ReturnType.HasValue
? " mov rdi, rax ; Exit with return value of entrypoint"
: " mov rdi, 0 ; Exit with default status code 0");
_builder.AppendLine(" mov rax, 60");
_builder.AppendLine(" mov rdi, 0");
_builder.AppendLine(" syscall");
foreach (var funcDefinition in _definitions.OfType<FuncDefinitionNode>())
foreach (var funcDefinition in _definitions.OfType<LocalFuncDefinitionNode>())
{
_builder.AppendLine();
GenerateFuncDefinition(funcDefinition);
}
_builder.AppendLine("""
; https://tuttlem.github.io/2013/01/08/strlen-implementation-in-nasm.html
strlen:
push rcx ; save and clear out counter
xor rcx, rcx
strlen_next:
cmp [rdi], byte 0 ; null byte yet?
jz strlen_null ; yes, get out
inc rcx ; char is ok, count it
inc rdi ; move to next char
jmp strlen_next ; process again
strlen_null:
mov rax, rcx ; rcx = the length (put in rax)
pop rcx ; restore rcx
ret ; get out
""");
_builder.AppendLine("""
strcmp:
xor rdx, rdx
strcmp_loop:
mov al, [rsi + rdx]
mov bl, [rdi + rdx]
inc rdx
cmp al, bl
jne strcmp_not_equal
cmp al, 0
je strcmp_equal
jmp strcmp_loop
strcmp_not_equal:
mov rax, 0
ret
strcmp_equal:
mov rax, 1
ret
""");
_builder.AppendLine();
_builder.AppendLine("section .data");
foreach (var str in _strings)
@@ -117,9 +91,10 @@ public class Generator
return _builder.ToString();
}
private void GenerateFuncDefinition(FuncDefinitionNode node)
private void GenerateFuncDefinition(LocalFuncDefinitionNode node)
{
var func = _symbolTable.ResolveFunc(node.Name, node.Parameters.Select(p => p.Type).ToList());
var func = _symbolTable.ResolveLocalFunc(node.Name, node.Parameters.Select(p => p.Type).ToList());
_builder.AppendLine($"; {node.ToString()}");
_builder.AppendLine($"{func.StartLabel}:");
_builder.AppendLine(" ; Set up stack frame");
@@ -154,7 +129,7 @@ public class Generator
_builder.AppendLine(" ret");
}
private void GenerateBlock(BlockNode block, Func func)
private void GenerateBlock(BlockNode block, LocalFunc func)
{
foreach (var statement in block.Statements)
{
@@ -162,7 +137,7 @@ public class Generator
}
}
private void GenerateStatement(StatementNode statement, Func func)
private void GenerateStatement(StatementNode statement, LocalFunc func)
{
switch (statement)
{
@@ -170,13 +145,13 @@ public class Generator
GenerateFuncCall(funcCallStatement.FuncCall, func);
break;
case ReturnNode @return:
GenerateReturn(func, @return);
GenerateReturn(@return, func);
break;
case SyscallStatementNode syscallStatement:
GenerateSyscall(syscallStatement.Syscall, func);
break;
case VariableAssignmentNode variableAssignment:
GenerateVariableAssignment(func, variableAssignment);
GenerateVariableAssignment(variableAssignment, func);
break;
case VariableReassignmentNode variableReassignment:
GenerateVariableReassignment(variableReassignment, func);
@@ -186,7 +161,7 @@ public class Generator
}
}
private void GenerateReturn(Func func, ReturnNode @return)
private void GenerateReturn(ReturnNode @return, LocalFunc func)
{
if (@return.Value.HasValue)
{
@@ -196,21 +171,21 @@ public class Generator
_builder.AppendLine($" jmp {func.EndLabel}");
}
private void GenerateVariableAssignment(Func func, VariableAssignmentNode variableAssignment)
private void GenerateVariableAssignment(VariableAssignmentNode variableAssignment, LocalFunc func)
{
var variable = func.ResolveLocalVariable(variableAssignment.Name);
GenerateExpression(variableAssignment.Value, func);
_builder.AppendLine($" mov [rbp - {variable.Offset}], rax");
}
private void GenerateVariableReassignment(VariableReassignmentNode variableReassignment, Func func)
private void GenerateVariableReassignment(VariableReassignmentNode variableReassignment, LocalFunc func)
{
var variable = func.ResolveLocalVariable(variableReassignment.Name);
GenerateExpression(variableReassignment.Value, func);
_builder.AppendLine($" mov [rbp - {variable.Offset}], rax");
}
private void GenerateExpression(ExpressionNode expression, Func func)
private void GenerateExpression(ExpressionNode expression, LocalFunc func)
{
switch (expression)
{
@@ -224,10 +199,7 @@ public class Generator
GenerateIdentifier(identifier, func);
break;
case LiteralNode literal:
GenerateLiteral(literal, func);
break;
case StrlenNode strlen:
GenerateStrlen(strlen, func);
GenerateLiteral(literal);
break;
case SyscallExpressionNode syscallExpression:
GenerateSyscall(syscallExpression.Syscall, func);
@@ -237,7 +209,7 @@ public class Generator
}
}
private void GenerateBinaryExpression(BinaryExpressionNode binaryExpression, Func func)
private void GenerateBinaryExpression(BinaryExpressionNode binaryExpression, LocalFunc func)
{
GenerateExpression(binaryExpression.Left, func);
_builder.AppendLine(" push rax");
@@ -396,7 +368,7 @@ public class Generator
}
}
private void GenerateIdentifier(IdentifierNode identifier, Func func)
private void GenerateIdentifier(IdentifierNode identifier, LocalFunc func)
{
var variable = func.ResolveVariable(identifier.Identifier);
@@ -417,7 +389,7 @@ public class Generator
}
}
private void GenerateLiteral(LiteralNode literal, Func func)
private void GenerateLiteral(LiteralNode literal)
{
switch (literal.Type)
{
@@ -459,14 +431,7 @@ public class Generator
}
}
private void GenerateStrlen(StrlenNode strlen, Func func)
{
GenerateExpression(strlen.String, func);
_builder.AppendLine(" mov rdi, rax");
_builder.AppendLine(" call strlen");
}
private void GenerateFuncCall(FuncCall funcCall, Func func)
private void GenerateFuncCall(FuncCall funcCall, LocalFunc func)
{
var symbol = _symbolTable.ResolveFunc(funcCall.Name, funcCall.Parameters.Select(p => p.Type).ToList());
string[] registers = ["rdi", "rsi", "rdx", "rcx", "r8", "r9"];
@@ -492,7 +457,7 @@ public class Generator
}
}
private void GenerateSyscall(Syscall syscall, Func func)
private void GenerateSyscall(Syscall syscall, LocalFunc func)
{
string[] registers = ["rax", "rdi", "rsi", "rdx", "r10", "r8", "r9"];

View File

@@ -5,7 +5,7 @@ namespace Nub.Lang.Generation;
public class SymbolTable
{
private readonly List<Func> _functions = [];
private readonly List<Func> _funcDefinitions = [];
private readonly List<GlobalVariable> _globalVariables = [];
private int _labelIndex;
@@ -19,26 +19,30 @@ public class SymbolTable
}
}
public void DefineFunc(FuncDefinitionNode funcDefinition)
public void DefineFunc(ExternFuncDefinitionNode externFuncDefinition)
{
var startLabel = $"func{++_labelIndex}";
var endLabel = $"endfunc{_labelIndex}";
var localVariables = ResolveFunctionVariables(funcDefinition);
_functions.Add(new Func(startLabel, endLabel, funcDefinition.Name, funcDefinition.Parameters, funcDefinition.ReturnType, _globalVariables.Concat<Variable>(localVariables.Variables).ToList(), localVariables.StackSize));
_funcDefinitions.Add(new ExternFunc(externFuncDefinition.Name, externFuncDefinition.Name, externFuncDefinition.Parameters, externFuncDefinition.ReturnType));
}
private (int StackSize, List<LocalVariable> Variables) ResolveFunctionVariables(FuncDefinitionNode funcDefinition)
public void DefineFunc(LocalFuncDefinitionNode localFuncDefinition)
{
var startLabel = $"func{++_labelIndex}";
var endLabel = $"func_end{_labelIndex}";
_funcDefinitions.Add(new LocalFunc(localFuncDefinition.Name, startLabel, endLabel, localFuncDefinition.Parameters, localFuncDefinition.ReturnType, _globalVariables.Concat<Variable>(ResolveFuncVariables(localFuncDefinition)).ToList()));
}
private static List<LocalVariable> ResolveFuncVariables(LocalFuncDefinitionNode localFuncDefinition)
{
var offset = 0;
List<LocalVariable> variables = [];
foreach (var parameter in funcDefinition.Parameters)
foreach (var parameter in localFuncDefinition.Parameters)
{
offset += 8;
variables.Add(new LocalVariable(parameter.Name, parameter.Type, offset));
}
foreach (var statement in funcDefinition.Body.Statements)
foreach (var statement in localFuncDefinition.Body.Statements)
{
if (statement is VariableAssignmentNode variableAssignment)
{
@@ -47,12 +51,12 @@ public class SymbolTable
}
}
return (offset, variables);
return variables;
}
public Func ResolveFunc(string name, IReadOnlyCollection<Type> parameterTypes)
{
var func = _functions.FirstOrDefault(f => f.Name == name && f.Parameters.Count == parameterTypes.Count && f.Parameters.Where((p, i) => p.Type == parameterTypes.ElementAt(i)).Count() == parameterTypes.Count);
var func = _funcDefinitions.FirstOrDefault(f => f.Name == name && f.Parameters.Count == parameterTypes.Count && f.Parameters.Where((p, i) => p.Type == parameterTypes.ElementAt(i)).Count() == parameterTypes.Count);
if (func == null)
{
throw new Exception($"Func {name}({string.Join(", ", parameterTypes)}) is not defined");
@@ -61,6 +65,26 @@ public class SymbolTable
return func;
}
/// <summary>
/// Resolves a function by name and parameter types and requires the result to be a <see cref="LocalFunc"/>.
/// </summary>
/// <param name="name">Name of the function to look up.</param>
/// <param name="parameterTypes">Parameter types used to select the matching definition.</param>
/// <returns>The resolved function as a <see cref="LocalFunc"/>.</returns>
/// <exception cref="Exception">
/// Thrown when the resolved function is not a <see cref="LocalFunc"/>. A failed lookup throws from
/// <see cref="ResolveFunc"/> and propagates unchanged.
/// </exception>
public LocalFunc ResolveLocalFunc(string name, IReadOnlyCollection<Type> parameterTypes)
{
    if (ResolveFunc(name, parameterTypes) is LocalFunc localFunc)
    {
        return localFunc;
    }

    // Report the signature that was looked up (same format as ResolveFunc's message)
    // instead of interpolating the Func object itself.
    throw new Exception($"Func {name}({string.Join(", ", parameterTypes)}) is not a local func");
}
/// <summary>
/// Resolves a function by name and parameter types and requires the result to be an <see cref="ExternFunc"/>.
/// </summary>
/// <param name="name">Name of the function to look up.</param>
/// <param name="parameterTypes">Parameter types used to select the matching definition.</param>
/// <returns>The resolved function as an <see cref="ExternFunc"/>.</returns>
/// <exception cref="Exception">
/// Thrown when the resolved function is not an <see cref="ExternFunc"/>. A failed lookup throws from
/// <see cref="ResolveFunc"/> and propagates unchanged.
/// </exception>
public ExternFunc ResolveExternFunc(string name, IReadOnlyCollection<Type> parameterTypes)
{
    if (ResolveFunc(name, parameterTypes) is ExternFunc externFunc)
    {
        return externFunc;
    }

    // Report the signature that was looked up (same format as ResolveFunc's message)
    // instead of interpolating the Func object itself.
    throw new Exception($"Func {name}({string.Join(", ", parameterTypes)}) is not an extern func");
}
public GlobalVariable ResolveGlobalVariable(string name)
{
var variable = _globalVariables.FirstOrDefault(v => v.Name == name);
@@ -89,15 +113,40 @@ public class GlobalVariable(string name, Type type, string identifier) : Variabl
public string Identifier { get; } = identifier;
}
public class Func(string startLabel, string endLabel, string name, IReadOnlyCollection<FuncParameter> parameters, Optional<Type> returnType, IReadOnlyCollection<Variable> variables, int stackAllocation)
public abstract class Func
{
public string StartLabel { get; } = startLabel;
public string EndLabel { get; } = endLabel;
public string Name { get; } = name;
public IReadOnlyCollection<FuncParameter> Parameters { get; } = parameters;
public Optional<Type> ReturnType { get; } = returnType;
public IReadOnlyCollection<Variable> Variables { get; } = variables;
public int StackAllocation { get; } = stackAllocation;
public Func(string name, string startLabel, IReadOnlyCollection<FuncParameter> parameters, Optional<Type> returnType)
{
Name = name;
Parameters = parameters;
ReturnType = returnType;
StartLabel = startLabel;
}
public string Name { get; }
public string StartLabel { get; }
public IReadOnlyCollection<FuncParameter> Parameters { get; }
public Optional<Type> ReturnType { get; }
}
/// <summary>
/// Symbol-table entry for a function defined through an <c>ExternFuncDefinitionNode</c>.
/// It adds no members of its own and forwards every constructor argument to <see cref="Func"/>.
/// </summary>
public class ExternFunc(string name, string startLabel, IReadOnlyCollection<FuncParameter> parameters, Optional<Type> returnType)
    : Func(name, startLabel, parameters, returnType)
{
}
public class LocalFunc : Func
{
public LocalFunc(string name, string startLabel, string endLabel, IReadOnlyCollection<FuncParameter> parameters, Optional<Type> returnType, IReadOnlyCollection<Variable> variables) : base(name, startLabel, parameters, returnType)
{
EndLabel = endLabel;
Variables = variables;
}
public string EndLabel { get; }
public IReadOnlyCollection<Variable> Variables { get; }
public int StackAllocation => Variables.OfType<LocalVariable>().Sum(variable => variable.Offset);
public Variable ResolveVariable(string name)
{

View File

@@ -1,3 +1,15 @@
let SYS_WRITE = 1;
let STD_IN = 0;
let STD_OUT = 1;
let STD_ERR = 2;
func main() {
syscall(60, 5 + 3 * 5);
print("test\n");
}
func print(msg: String) {
syscall(SYS_WRITE, STD_OUT, msg, strlen(msg));
}
extern func strlen(msg: String): int64;
extern func strcmp(a: String, b: String): bool;

View File

@@ -7,6 +7,7 @@ public class Lexer
private static readonly Dictionary<string, Symbol> Keywords = new()
{
["func"] = Symbol.Func,
["extern"] = Symbol.Extern,
["return"] = Symbol.Return,
["let"] = Symbol.Let,
};

View File

@@ -8,6 +8,7 @@ public class SymbolToken(Symbol symbol) : Token
public enum Symbol
{
Whitespace,
Extern,
Func,
Return,
Let,

View File

@@ -0,0 +1,6 @@
#!/bin/sh
# Assembles the generated program (out.asm) and the core string routines,
# then links the resulting objects into the executable ./out.
#
# Stop at the first failing command so the link step never runs against
# missing or stale object files and the exit status reflects the failure.
set -e

nasm -g -felf64 out.asm -o out.o
nasm -g -felf64 core/strlen.asm -o strlen.o
nasm -g -felf64 core/strcmp.asm -o strcmp.o
ld -o out out.o strlen.o strcmp.o

View File

@@ -0,0 +1,20 @@
; strcmp: byte-wise equality test of two zero-terminated strings.
;   in:       rdi, rsi = addresses of the two strings
;   out:      rax = 1 when every byte up to and including the terminator matches, 0 otherwise
;   clobbers: rdx (byte index), flags
global strcmp               ; export the label this file actually defines

section .text

strcmp:
    xor rdx, rdx            ; rdx = index of the byte pair being compared
.loop:
    mov al, [rsi + rdx]
    cmp al, [rdi + rdx]     ; compare against memory directly so rbx (callee-saved
                            ; under the System V AMD64 convention) is left untouched
    jne .not_equal
    inc rdx
    cmp al, 0               ; bytes matched; a zero byte means both strings ended here
    je .equal
    jmp .loop
.not_equal:
    mov rax, 0
    ret
.equal:
    mov rax, 1
    ret

View File

@@ -0,0 +1,13 @@
; strlen: counts the bytes that precede the first zero byte of a string.
;   in:   rdi = address of the string
;   out:  rax = number of bytes before the zero terminator
;   note: rdi is advanced while scanning, so it does not hold the original
;         address on return
global strlen
section .text
strlen:
xor rax, rax ; rax = running length, starts at 0
.loop:
cmp byte [rdi], 0 ; reached the terminating zero byte?
jz .done
inc rax ; count this byte
inc rdi ; move on to the next byte
jmp .loop
.done:
ret ; length is returned in rax

View File

@@ -1,9 +1,3 @@
#!/bin/bash
nasm -g -felf64 out.asm -o out.o
ld out.o -o out
#!/bin/sh
./build.sh
gdb -tui out
rm out.o
rm out

View File

@@ -1,10 +1,4 @@
#!/bin/bash
nasm -g -felf64 out.asm -o out.o
ld out.o -o out
#!/bin/sh
./build.sh
./out
echo "Process exited with status code $?"
rm out.o
rm out

View File

@@ -0,0 +1,12 @@
using Nub.Core;
namespace Nub.Lang.Parsing;
/// <summary>
/// Definition node for an extern function: a name, a parameter list and an optional
/// return type, with no body.
/// </summary>
public class ExternFuncDefinitionNode : DefinitionNode
{
    public ExternFuncDefinitionNode(string name, IReadOnlyCollection<FuncParameter> parameters, Optional<Type> returnType)
    {
        Name = name;
        Parameters = parameters;
        ReturnType = returnType;
    }

    public string Name { get; }

    public IReadOnlyCollection<FuncParameter> Parameters { get; }

    public Optional<Type> ReturnType { get; }

    /// <summary>
    /// Formats the definition as <c>name(param, param)</c>, followed by <c>: returnType</c>
    /// when a return type is present.
    /// </summary>
    public override string ToString()
    {
        var parameterList = string.Join(", ", Parameters.Select(parameter => parameter.ToString()));
        var returnSuffix = ReturnType.HasValue ? ": " + ReturnType.Value : "";
        return $"{Name}({parameterList}){returnSuffix}";
    }
}

View File

@@ -2,7 +2,7 @@
namespace Nub.Lang.Parsing;
public class FuncDefinitionNode(string name, IReadOnlyCollection<FuncParameter> parameters, BlockNode body, Optional<Type> returnType) : DefinitionNode
public class LocalFuncDefinitionNode(string name, IReadOnlyCollection<FuncParameter> parameters, BlockNode body, Optional<Type> returnType) : DefinitionNode
{
public string Name { get; } = name;
public IReadOnlyCollection<FuncParameter> Parameters { get; } = parameters;

View File

@@ -33,6 +33,7 @@ public class Parser
{
Symbol.Let => ParseGlobalVariableDefinition(),
Symbol.Func => ParseFuncDefinition(),
Symbol.Extern => ParseExternFuncDefinition(),
_ => throw new Exception("Unexpected symbol: " + keyword.Symbol)
};
}
@@ -47,7 +48,7 @@ public class Parser
return new GlobalVariableDefinitionNode(name.Value, value);
}
private FuncDefinitionNode ParseFuncDefinition()
private LocalFuncDefinitionNode ParseFuncDefinition()
{
var name = ExpectIdentifier();
List<FuncParameter> parameters = [];
@@ -69,7 +70,33 @@ public class Parser
var body = ParseBlock();
return new FuncDefinitionNode(name.Value, parameters, body, returnType);
return new LocalFuncDefinitionNode(name.Value, parameters, body, returnType);
}
/// <summary>
/// Parses <c>func name(parameters)[: returnType];</c> into an <see cref="ExternFuncDefinitionNode"/>.
/// Presumably called after the definition dispatcher has already consumed the <c>extern</c>
/// keyword (verify against the caller).
/// </summary>
/// <returns>The parsed extern function definition.</returns>
private ExternFuncDefinitionNode ParseExternFuncDefinition()
{
    ExpectSymbol(Symbol.Func);
    var identifier = ExpectIdentifier();

    ExpectSymbol(Symbol.OpenParen);
    var parameters = new List<FuncParameter>();

    // An immediate ')' means an empty parameter list; otherwise keep reading
    // parameters until the closing parenthesis is found.
    var listIsOpen = !TryExpectSymbol(Symbol.CloseParen);
    while (listIsOpen && !TryExpectSymbol(Symbol.CloseParen))
    {
        parameters.Add(ParseFuncParameter());
        // The result is ignored, so a missing comma between parameters is not reported here.
        TryExpectSymbol(Symbol.Comma);
    }

    // The return type is optional and introduced by ':'.
    var returnType = Optional<Type>.Empty();
    if (TryExpectSymbol(Symbol.Colon))
    {
        returnType = ParseType();
    }

    // Extern declarations have no body; they end with ';'.
    ExpectSymbol(Symbol.Semicolon);

    return new ExternFuncDefinitionNode(identifier.Value, parameters, returnType);
}
private FuncParameter ParseFuncParameter()
@@ -268,11 +295,6 @@ public class Parser
return new SyscallExpressionNode(new Syscall(parameters));
}
if (identifier.Value == "strlen" && parameters.Count == 1)
{
return new StrlenNode(parameters[0]);
}
return new FuncCallExpressionNode(new FuncCall(identifier.Value, parameters));
}

View File

@@ -1,6 +0,0 @@
namespace Nub.Lang.Parsing;
public class StrlenNode(ExpressionNode @string) : ExpressionNode
{
public ExpressionNode String { get; } = @string;
}

View File

@@ -18,6 +18,7 @@ public record PrimitiveType : Type
"bool" => PrimitiveTypeKind.Bool,
"char" => PrimitiveTypeKind.Char,
"int64" => PrimitiveTypeKind.Int64,
"int32" => PrimitiveTypeKind.Int32,
_ => throw new ArgumentOutOfRangeException(nameof(value), value, null)
};

View File

@@ -1,16 +1,35 @@
using Nub.Lang.Parsing;
using Nub.Core;
using Nub.Lang.Parsing;
namespace Nub.Lang.Typing;
/// <summary>
/// Function description used by the expression typer: name, parameters, an optional body
/// and an optional return type. The body is optional so that functions without one can be
/// represented alongside functions that have one.
/// </summary>
public class Func
{
    public Func(string name, IReadOnlyCollection<FuncParameter> parameters, Optional<BlockNode> body, Optional<Type> returnType)
    {
        Name = name;
        Parameters = parameters;
        Body = body;
        ReturnType = returnType;
    }

    public string Name { get; }

    public IReadOnlyCollection<FuncParameter> Parameters { get; }

    public Optional<BlockNode> Body { get; }

    public Optional<Type> ReturnType { get; }
}
public class ExpressionTyper
{
private readonly IReadOnlyCollection<FuncDefinitionNode> _functions;
private readonly IReadOnlyCollection<Func> _functions;
private readonly IReadOnlyCollection<GlobalVariableDefinitionNode> _variableDefinitions;
private readonly Stack<Variable> _variables;
public ExpressionTyper(IReadOnlyCollection<DefinitionNode> definitions)
{
_functions = definitions.OfType<FuncDefinitionNode>().ToList();
var functions = definitions
.OfType<LocalFuncDefinitionNode>()
.Select(f => new Func(f.Name, f.Parameters, f.Body, f.ReturnType))
.ToList();
var externFunctions = definitions
.OfType<ExternFuncDefinitionNode>()
.Select(f => new Func(f.Name, f.Parameters, Optional<BlockNode>.Empty(), f.ReturnType))
.ToList();
_functions = functions.Concat(externFunctions).ToList();
_variableDefinitions = definitions.OfType<GlobalVariableDefinitionNode>().ToList();
_variables = new Stack<Variable>();
}
@@ -31,8 +50,12 @@ public class ExpressionTyper
{
_variables.Push(new Variable(parameter.Name, parameter.Type));
}
PopulateBlock(function.Body);
for (var i = 0; i < function.Parameters.Count(); i++)
if (function.Body.HasValue)
{
PopulateBlock(function.Body.Value);
}
for (var i = 0; i < function.Parameters.Count; i++)
{
_variables.Pop();
}
@@ -127,9 +150,6 @@ public class ExpressionTyper
case LiteralNode literal:
PopulateLiteral(literal);
break;
case StrlenNode strlen:
PopulateStrlen(strlen);
break;
case SyscallExpressionNode syscall:
PopulateSyscallExpression(syscall);
break;
@@ -203,11 +223,6 @@ public class ExpressionTyper
literal.Type = literal.LiteralType;
}
private static void PopulateStrlen(StrlenNode strlen)
{
strlen.Type = new PrimitiveType(PrimitiveTypeKind.Int64);
}
private void PopulateSyscallExpression(SyscallExpressionNode syscall)
{
foreach (var parameter in syscall.Syscall.Parameters)