If statements

This commit is contained in:
nub31
2025-01-27 20:35:03 +01:00
parent 5d5781dbd7
commit 37ee8fafd3
9 changed files with 186 additions and 17 deletions

View File

@@ -13,6 +13,7 @@ public class Generator
private readonly Dictionary<string, string> _strings;
private readonly HashSet<string> _externFuncDefinitions;
private int _stringIndex;
private int _labelIndex;
public Generator(IReadOnlyCollection<DefinitionNode> definitions)
{
@@ -144,6 +145,9 @@ public class Generator
case FuncCallStatementNode funcCallStatement:
GenerateFuncCall(funcCallStatement.FuncCall, func);
break;
case IfNode ifStatement:
GenerateIf(ifStatement, func);
break;
case ReturnNode @return:
GenerateReturn(@return, func);
break;
@@ -161,6 +165,33 @@ public class Generator
}
}
private void GenerateIf(IfNode ifStatement, LocalFunc func)
{
var endLabel = CreateLabel();
GenerateIf(ifStatement, endLabel, func);
_builder.AppendLine($"{endLabel}:");
}
private void GenerateIf(IfNode ifStatement, string endLabel, LocalFunc func)
{
var nextLabel = CreateLabel();
GenerateExpression(ifStatement.Condition, func);
_builder.AppendLine(" cmp rax, 0");
_builder.AppendLine($" je {nextLabel}");
GenerateBlock(ifStatement.Body, func);
_builder.AppendLine($" jmp {endLabel}");
_builder.AppendLine($"{nextLabel}:");
if (ifStatement.Else.HasValue)
{
ifStatement.Else.Value.Match
(
elseIfStatement => GenerateIf(elseIfStatement, endLabel, func),
elseStatement => GenerateBlock(elseStatement, func)
);
}
}
private void GenerateReturn(ReturnNode @return, LocalFunc func)
{
if (@return.Value.HasValue)
@@ -474,4 +505,9 @@ public class Generator
_builder.AppendLine(" syscall");
}
private string CreateLabel()
{
return $"label{++_labelIndex}";
}
}

View File

@@ -21,11 +21,37 @@ public class SymbolTable
public void DefineFunc(ExternFuncDefinitionNode externFuncDefinition)
{
var existing = _funcDefinitions
.FirstOrDefault(f => f
.SignatureMatches
(
externFuncDefinition.Name,
externFuncDefinition.Parameters.Select(p => p.Type).ToList()
));
if (existing != null)
{
throw new Exception($"Func {existing} is already defined");
}
_funcDefinitions.Add(new ExternFunc(externFuncDefinition.Name, externFuncDefinition.Name, externFuncDefinition.Parameters, externFuncDefinition.ReturnType));
}
public void DefineFunc(LocalFuncDefinitionNode localFuncDefinition)
{
var existing = _funcDefinitions
.FirstOrDefault(f => f
.SignatureMatches
(
localFuncDefinition.Name,
localFuncDefinition.Parameters.Select(p => p.Type).ToList()
));
if (existing != null)
{
throw new Exception($"Func {existing} is already defined");
}
var startLabel = $"func{++_labelIndex}";
var endLabel = $"func_end{_labelIndex}";
_funcDefinitions.Add(new LocalFunc(localFuncDefinition.Name, startLabel, endLabel, localFuncDefinition.Parameters, localFuncDefinition.ReturnType, _globalVariables.Concat<Variable>(ResolveFuncVariables(localFuncDefinition)).ToList()));
@@ -56,7 +82,7 @@ public class SymbolTable
public Func ResolveFunc(string name, IReadOnlyCollection<Type> parameterTypes)
{
var func = _funcDefinitions.FirstOrDefault(f => f.Name == name && f.Parameters.Count == parameterTypes.Count && f.Parameters.Where((p, i) => p.Type == parameterTypes.ElementAt(i)).Count() == parameterTypes.Count);
var func = _funcDefinitions.FirstOrDefault(f => f.SignatureMatches(name, parameterTypes));
if (func == null)
{
throw new Exception($"Func {name}({string.Join(", ", parameterTypes)}) is not defined");
@@ -115,7 +141,7 @@ public class GlobalVariable(string name, Type type, string identifier) : Variabl
public abstract class Func
{
public Func(string name, string startLabel, IReadOnlyCollection<FuncParameter> parameters, Optional<Type> returnType)
protected Func(string name, string startLabel, IReadOnlyCollection<FuncParameter> parameters, Optional<Type> returnType)
{
Name = name;
Parameters = parameters;
@@ -127,6 +153,13 @@ public abstract class Func
public string StartLabel { get; }
public IReadOnlyCollection<FuncParameter> Parameters { get; }
public Optional<Type> ReturnType { get; }
public bool SignatureMatches(string name, IReadOnlyCollection<Type> parameterTypes)
{
return Name == name
&& Parameters.Count == parameterTypes.Count
&& Parameters.Where((p, i) => p.Type == parameterTypes.ElementAt(i)).Count() == parameterTypes.Count;
}
}
public class ExternFunc : Func

View File

@@ -4,12 +4,45 @@ let STD_OUT = 1;
let STD_ERR = 2;
func main() {
print("test\n");
println("test");
println(true);
if true {
println("1");
} else if false {
println("2");
} else if true {
println("3");
} else {
println("4");
}
}
func print(msg: String) {
syscall(SYS_WRITE, STD_OUT, msg, strlen(msg));
}
func print(value: bool) {
if value {
print("true");
} else {
print("false");
}
}
func println() {
print("\n");
}
func println(msg: String) {
print(msg);
println();
}
func println(value: bool) {
print(value);
println();
}
extern func strlen(msg: String): int64;
extern func strcmp(a: String, b: String): bool;

View File

@@ -10,6 +10,8 @@ public class Lexer
["extern"] = Symbol.Extern,
["return"] = Symbol.Return,
["let"] = Symbol.Let,
["if"] = Symbol.If,
["else"] = Symbol.Else,
};
private static readonly Dictionary<char[], Symbol> Chians = new()
@@ -81,6 +83,11 @@ public class Lexer
return new SymbolToken(keywordSymbol);
}
if (buffer is "true" or "false")
{
return new LiteralToken(new PrimitiveType(PrimitiveTypeKind.Bool), buffer);
}
return new IdentifierToken(buffer);
}

View File

@@ -12,6 +12,8 @@ public enum Symbol
Func,
Return,
Let,
If,
Else,
Semicolon,
Colon,
OpenParen,

View File

@@ -8,7 +8,7 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Nub.Core" Version="1.0.0" />
<PackageReference Include="Nub.Core" Version="1.0.1" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,10 @@
using Nub.Core;
namespace Nub.Lang.Parsing;
public class IfNode(ExpressionNode condition, BlockNode body, Optional<Variant<IfNode, BlockNode>> @else) : StatementNode
{
public ExpressionNode Condition { get; } = condition;
public BlockNode Body { get; } = body;
public Optional<Variant<IfNode, BlockNode>> Else { get; } = @else;
}

View File

@@ -154,22 +154,15 @@ public class Parser
{
case Symbol.Return:
{
var value = Optional<ExpressionNode>.Empty();
if (!TryExpectSymbol(Symbol.Semicolon))
{
value = ParseExpression();
ExpectSymbol(Symbol.Semicolon);
}
return new ReturnNode(value);
return ParseReturn();
}
case Symbol.Let:
{
var name = ExpectIdentifier().Value;
ExpectSymbol(Symbol.Assign);
var value = ParseExpression();
ExpectSymbol(Symbol.Semicolon);
return new VariableAssignmentNode(name, value);
return ParseVariableAssignment();
}
case Symbol.If:
{
return ParseIf();
}
default:
{
@@ -184,6 +177,44 @@ public class Parser
}
}
private ReturnNode ParseReturn()
{
var value = Optional<ExpressionNode>.Empty();
if (!TryExpectSymbol(Symbol.Semicolon))
{
value = ParseExpression();
ExpectSymbol(Symbol.Semicolon);
}
return new ReturnNode(value);
}
private VariableAssignmentNode ParseVariableAssignment()
{
var name = ExpectIdentifier().Value;
ExpectSymbol(Symbol.Assign);
var value = ParseExpression();
ExpectSymbol(Symbol.Semicolon);
return new VariableAssignmentNode(name, value);
}
private IfNode ParseIf()
{
var condition = ParseExpression();
var body = ParseBlock();
var elseStatement = Optional<Variant<IfNode, BlockNode>>.Empty();
if (TryExpectSymbol(Symbol.Else))
{
elseStatement = TryExpectSymbol(Symbol.If)
? (Variant<IfNode, BlockNode>)ParseIf()
: (Variant<IfNode, BlockNode>)ParseBlock();
}
return new IfNode(condition, body, elseStatement);
}
private ExpressionNode ParseExpression(int precedence = 0)
{
var left = ParsePrimaryExpression();

View File

@@ -82,6 +82,9 @@ public class ExpressionTyper
case FuncCallStatementNode funcCall:
PopulateFuncCallStatement(funcCall);
break;
case IfNode ifStatement:
PopulateIf(ifStatement);
break;
case ReturnNode returnNode:
PopulateReturn(returnNode);
break;
@@ -107,6 +110,20 @@ public class ExpressionTyper
}
}
private void PopulateIf(IfNode ifStatement)
{
PopulateExpression(ifStatement.Condition);
PopulateBlock(ifStatement.Body);
if (ifStatement.Else.HasValue)
{
ifStatement.Else.Value.Match
(
PopulateIf,
PopulateBlock
);
}
}
private void PopulateSyscallStatement(SyscallStatementNode syscall)
{
foreach (var parameter in syscall.Syscall.Parameters)