If statements

This commit is contained in:
nub31
2025-01-27 20:35:03 +01:00
parent 5d5781dbd7
commit 37ee8fafd3
9 changed files with 186 additions and 17 deletions

View File

@@ -13,6 +13,7 @@ public class Generator
private readonly Dictionary<string, string> _strings; private readonly Dictionary<string, string> _strings;
private readonly HashSet<string> _externFuncDefinitions; private readonly HashSet<string> _externFuncDefinitions;
private int _stringIndex; private int _stringIndex;
private int _labelIndex;
public Generator(IReadOnlyCollection<DefinitionNode> definitions) public Generator(IReadOnlyCollection<DefinitionNode> definitions)
{ {
@@ -144,6 +145,9 @@ public class Generator
case FuncCallStatementNode funcCallStatement: case FuncCallStatementNode funcCallStatement:
GenerateFuncCall(funcCallStatement.FuncCall, func); GenerateFuncCall(funcCallStatement.FuncCall, func);
break; break;
case IfNode ifStatement:
GenerateIf(ifStatement, func);
break;
case ReturnNode @return: case ReturnNode @return:
GenerateReturn(@return, func); GenerateReturn(@return, func);
break; break;
@@ -161,6 +165,33 @@ public class Generator
} }
} }
private void GenerateIf(IfNode ifStatement, LocalFunc func)
{
var endLabel = CreateLabel();
GenerateIf(ifStatement, endLabel, func);
_builder.AppendLine($"{endLabel}:");
}
private void GenerateIf(IfNode ifStatement, string endLabel, LocalFunc func)
{
var nextLabel = CreateLabel();
GenerateExpression(ifStatement.Condition, func);
_builder.AppendLine(" cmp rax, 0");
_builder.AppendLine($" je {nextLabel}");
GenerateBlock(ifStatement.Body, func);
_builder.AppendLine($" jmp {endLabel}");
_builder.AppendLine($"{nextLabel}:");
if (ifStatement.Else.HasValue)
{
ifStatement.Else.Value.Match
(
elseIfStatement => GenerateIf(elseIfStatement, endLabel, func),
elseStatement => GenerateBlock(elseStatement, func)
);
}
}
private void GenerateReturn(ReturnNode @return, LocalFunc func) private void GenerateReturn(ReturnNode @return, LocalFunc func)
{ {
if (@return.Value.HasValue) if (@return.Value.HasValue)
@@ -474,4 +505,9 @@ public class Generator
_builder.AppendLine(" syscall"); _builder.AppendLine(" syscall");
} }
private string CreateLabel()
{
return $"label{++_labelIndex}";
}
} }

View File

@@ -21,11 +21,37 @@ public class SymbolTable
public void DefineFunc(ExternFuncDefinitionNode externFuncDefinition) public void DefineFunc(ExternFuncDefinitionNode externFuncDefinition)
{ {
var existing = _funcDefinitions
.FirstOrDefault(f => f
.SignatureMatches
(
externFuncDefinition.Name,
externFuncDefinition.Parameters.Select(p => p.Type).ToList()
));
if (existing != null)
{
throw new Exception($"Func {existing} is already defined");
}
_funcDefinitions.Add(new ExternFunc(externFuncDefinition.Name, externFuncDefinition.Name, externFuncDefinition.Parameters, externFuncDefinition.ReturnType)); _funcDefinitions.Add(new ExternFunc(externFuncDefinition.Name, externFuncDefinition.Name, externFuncDefinition.Parameters, externFuncDefinition.ReturnType));
} }
public void DefineFunc(LocalFuncDefinitionNode localFuncDefinition) public void DefineFunc(LocalFuncDefinitionNode localFuncDefinition)
{ {
var existing = _funcDefinitions
.FirstOrDefault(f => f
.SignatureMatches
(
localFuncDefinition.Name,
localFuncDefinition.Parameters.Select(p => p.Type).ToList()
));
if (existing != null)
{
throw new Exception($"Func {existing} is already defined");
}
var startLabel = $"func{++_labelIndex}"; var startLabel = $"func{++_labelIndex}";
var endLabel = $"func_end{_labelIndex}"; var endLabel = $"func_end{_labelIndex}";
_funcDefinitions.Add(new LocalFunc(localFuncDefinition.Name, startLabel, endLabel, localFuncDefinition.Parameters, localFuncDefinition.ReturnType, _globalVariables.Concat<Variable>(ResolveFuncVariables(localFuncDefinition)).ToList())); _funcDefinitions.Add(new LocalFunc(localFuncDefinition.Name, startLabel, endLabel, localFuncDefinition.Parameters, localFuncDefinition.ReturnType, _globalVariables.Concat<Variable>(ResolveFuncVariables(localFuncDefinition)).ToList()));
@@ -56,7 +82,7 @@ public class SymbolTable
public Func ResolveFunc(string name, IReadOnlyCollection<Type> parameterTypes) public Func ResolveFunc(string name, IReadOnlyCollection<Type> parameterTypes)
{ {
var func = _funcDefinitions.FirstOrDefault(f => f.Name == name && f.Parameters.Count == parameterTypes.Count && f.Parameters.Where((p, i) => p.Type == parameterTypes.ElementAt(i)).Count() == parameterTypes.Count); var func = _funcDefinitions.FirstOrDefault(f => f.SignatureMatches(name, parameterTypes));
if (func == null) if (func == null)
{ {
throw new Exception($"Func {name}({string.Join(", ", parameterTypes)}) is not defined"); throw new Exception($"Func {name}({string.Join(", ", parameterTypes)}) is not defined");
@@ -115,7 +141,7 @@ public class GlobalVariable(string name, Type type, string identifier) : Variabl
public abstract class Func public abstract class Func
{ {
public Func(string name, string startLabel, IReadOnlyCollection<FuncParameter> parameters, Optional<Type> returnType) protected Func(string name, string startLabel, IReadOnlyCollection<FuncParameter> parameters, Optional<Type> returnType)
{ {
Name = name; Name = name;
Parameters = parameters; Parameters = parameters;
@@ -127,6 +153,13 @@ public abstract class Func
public string StartLabel { get; } public string StartLabel { get; }
public IReadOnlyCollection<FuncParameter> Parameters { get; } public IReadOnlyCollection<FuncParameter> Parameters { get; }
public Optional<Type> ReturnType { get; } public Optional<Type> ReturnType { get; }
public bool SignatureMatches(string name, IReadOnlyCollection<Type> parameterTypes)
{
return Name == name
&& Parameters.Count == parameterTypes.Count
&& Parameters.Where((p, i) => p.Type == parameterTypes.ElementAt(i)).Count() == parameterTypes.Count;
}
} }
public class ExternFunc : Func public class ExternFunc : Func

View File

@@ -4,12 +4,45 @@ let STD_OUT = 1;
let STD_ERR = 2; let STD_ERR = 2;
func main() { func main() {
print("test\n"); println("test");
println(true);
if true {
println("1");
} else if false {
println("2");
} else if true {
println("3");
} else {
println("4");
}
} }
func print(msg: String) { func print(msg: String) {
syscall(SYS_WRITE, STD_OUT, msg, strlen(msg)); syscall(SYS_WRITE, STD_OUT, msg, strlen(msg));
} }
func print(value: bool) {
if value {
print("true");
} else {
print("false");
}
}
func println() {
print("\n");
}
func println(msg: String) {
print(msg);
println();
}
func println(value: bool) {
print(value);
println();
}
extern func strlen(msg: String): int64; extern func strlen(msg: String): int64;
extern func strcmp(a: String, b: String): bool; extern func strcmp(a: String, b: String): bool;

View File

@@ -10,6 +10,8 @@ public class Lexer
["extern"] = Symbol.Extern, ["extern"] = Symbol.Extern,
["return"] = Symbol.Return, ["return"] = Symbol.Return,
["let"] = Symbol.Let, ["let"] = Symbol.Let,
["if"] = Symbol.If,
["else"] = Symbol.Else,
}; };
private static readonly Dictionary<char[], Symbol> Chians = new() private static readonly Dictionary<char[], Symbol> Chians = new()
@@ -80,6 +82,11 @@ public class Lexer
{ {
return new SymbolToken(keywordSymbol); return new SymbolToken(keywordSymbol);
} }
if (buffer is "true" or "false")
{
return new LiteralToken(new PrimitiveType(PrimitiveTypeKind.Bool), buffer);
}
return new IdentifierToken(buffer); return new IdentifierToken(buffer);
} }

View File

@@ -12,6 +12,8 @@ public enum Symbol
Func, Func,
Return, Return,
Let, Let,
If,
Else,
Semicolon, Semicolon,
Colon, Colon,
OpenParen, OpenParen,

View File

@@ -8,7 +8,7 @@
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<PackageReference Include="Nub.Core" Version="1.0.0" /> <PackageReference Include="Nub.Core" Version="1.0.1" />
</ItemGroup> </ItemGroup>
</Project> </Project>

View File

@@ -0,0 +1,10 @@
using Nub.Core;
namespace Nub.Lang.Parsing;
public class IfNode(ExpressionNode condition, BlockNode body, Optional<Variant<IfNode, BlockNode>> @else) : StatementNode
{
public ExpressionNode Condition { get; } = condition;
public BlockNode Body { get; } = body;
public Optional<Variant<IfNode, BlockNode>> Else { get; } = @else;
}

View File

@@ -154,22 +154,15 @@ public class Parser
{ {
case Symbol.Return: case Symbol.Return:
{ {
var value = Optional<ExpressionNode>.Empty(); return ParseReturn();
if (!TryExpectSymbol(Symbol.Semicolon))
{
value = ParseExpression();
ExpectSymbol(Symbol.Semicolon);
}
return new ReturnNode(value);
} }
case Symbol.Let: case Symbol.Let:
{ {
var name = ExpectIdentifier().Value; return ParseVariableAssignment();
ExpectSymbol(Symbol.Assign); }
var value = ParseExpression(); case Symbol.If:
ExpectSymbol(Symbol.Semicolon); {
return new VariableAssignmentNode(name, value); return ParseIf();
} }
default: default:
{ {
@@ -184,6 +177,44 @@ public class Parser
} }
} }
private ReturnNode ParseReturn()
{
var value = Optional<ExpressionNode>.Empty();
if (!TryExpectSymbol(Symbol.Semicolon))
{
value = ParseExpression();
ExpectSymbol(Symbol.Semicolon);
}
return new ReturnNode(value);
}
private VariableAssignmentNode ParseVariableAssignment()
{
var name = ExpectIdentifier().Value;
ExpectSymbol(Symbol.Assign);
var value = ParseExpression();
ExpectSymbol(Symbol.Semicolon);
return new VariableAssignmentNode(name, value);
}
private IfNode ParseIf()
{
var condition = ParseExpression();
var body = ParseBlock();
var elseStatement = Optional<Variant<IfNode, BlockNode>>.Empty();
if (TryExpectSymbol(Symbol.Else))
{
elseStatement = TryExpectSymbol(Symbol.If)
? (Variant<IfNode, BlockNode>)ParseIf()
: (Variant<IfNode, BlockNode>)ParseBlock();
}
return new IfNode(condition, body, elseStatement);
}
private ExpressionNode ParseExpression(int precedence = 0) private ExpressionNode ParseExpression(int precedence = 0)
{ {
var left = ParsePrimaryExpression(); var left = ParsePrimaryExpression();

View File

@@ -82,6 +82,9 @@ public class ExpressionTyper
case FuncCallStatementNode funcCall: case FuncCallStatementNode funcCall:
PopulateFuncCallStatement(funcCall); PopulateFuncCallStatement(funcCall);
break; break;
case IfNode ifStatement:
PopulateIf(ifStatement);
break;
case ReturnNode returnNode: case ReturnNode returnNode:
PopulateReturn(returnNode); PopulateReturn(returnNode);
break; break;
@@ -107,6 +110,20 @@ public class ExpressionTyper
} }
} }
private void PopulateIf(IfNode ifStatement)
{
PopulateExpression(ifStatement.Condition);
PopulateBlock(ifStatement.Body);
if (ifStatement.Else.HasValue)
{
ifStatement.Else.Value.Match
(
PopulateIf,
PopulateBlock
);
}
}
private void PopulateSyscallStatement(SyscallStatementNode syscall) private void PopulateSyscallStatement(SyscallStatementNode syscall)
{ {
foreach (var parameter in syscall.Syscall.Parameters) foreach (var parameter in syscall.Syscall.Parameters)