Big type inference improvements

This commit is contained in:
nub31
2025-10-22 12:55:31 +02:00
parent e2da6cccff
commit 93cef598e8
8 changed files with 205 additions and 163 deletions

View File

@@ -7,10 +7,7 @@ using NubLang.Syntax;
var diagnostics = new List<Diagnostic>();
var syntaxTrees = new List<SyntaxTree>();
var nubFiles = args.Where(x => Path.GetExtension(x) == ".nub").ToArray();
var objectFileArgs = args.Where(x => Path.GetExtension(x) is ".o" or ".a").ToArray();
foreach (var file in nubFiles)
foreach (var file in args)
{
var tokenizer = new Tokenizer(file, File.ReadAllText(file));
tokenizer.Tokenize();
@@ -26,7 +23,7 @@ foreach (var file in nubFiles)
var modules = Module.Collect(syntaxTrees);
var compilationUnits = new List<CompilationUnit>();
for (var i = 0; i < nubFiles.Length; i++)
for (var i = 0; i < args.Length; i++)
{
var typeChecker = new TypeChecker(syntaxTrees[i], modules);
var compilationUnit = typeChecker.Check();
@@ -49,9 +46,9 @@ var cPaths = new List<string>();
Directory.CreateDirectory(".build");
for (var i = 0; i < nubFiles.Length; i++)
for (var i = 0; i < args.Length; i++)
{
var file = nubFiles[i];
var file = args[i];
var compilationUnit = compilationUnits[i];
var generator = new Generator(compilationUnit);

View File

@@ -80,17 +80,29 @@ public abstract record LValueExpressionNode(List<Token> Tokens, NubType Type) :
public abstract record RValueExpressionNode(List<Token> Tokens, NubType Type) : ExpressionNode(Tokens, Type);
public record StringLiteralNode(List<Token> Tokens, NubType Type, string Value) : RValueExpressionNode(Tokens, Type);
public record StringLiteralNode(List<Token> Tokens, string Value) : RValueExpressionNode(Tokens, new NubStringType());
public record CStringLiteralNode(List<Token> Tokens, NubType Type, string Value) : RValueExpressionNode(Tokens, Type);
public record CStringLiteralNode(List<Token> Tokens, string Value) : RValueExpressionNode(Tokens, new NubCStringType());
public record IntLiteralNode(List<Token> Tokens, NubType Type, long Value) : RValueExpressionNode(Tokens, Type);
public record I8LiteralNode(List<Token> Tokens, sbyte Value) : RValueExpressionNode(Tokens, new NubIntType(true, 8));
public record UIntLiteralNode(List<Token> Tokens, NubType Type, ulong Value) : RValueExpressionNode(Tokens, Type);
public record I16LiteralNode(List<Token> Tokens, short Value) : RValueExpressionNode(Tokens, new NubIntType(true, 16));
public record Float32LiteralNode(List<Token> Tokens, NubType Type, float Value) : RValueExpressionNode(Tokens, Type);
public record I32LiteralNode(List<Token> Tokens, int Value) : RValueExpressionNode(Tokens, new NubIntType(true, 32));
public record Float64LiteralNode(List<Token> Tokens, NubType Type, double Value) : RValueExpressionNode(Tokens, Type);
public record I64LiteralNode(List<Token> Tokens, long Value) : RValueExpressionNode(Tokens, new NubIntType(true, 64));
public record U8LiteralNode(List<Token> Tokens, byte Value) : RValueExpressionNode(Tokens, new NubIntType(false, 8));
public record U16LiteralNode(List<Token> Tokens, ushort Value) : RValueExpressionNode(Tokens, new NubIntType(false, 16));
public record U32LiteralNode(List<Token> Tokens, uint Value) : RValueExpressionNode(Tokens, new NubIntType(false, 32));
public record U64LiteralNode(List<Token> Tokens, ulong Value) : RValueExpressionNode(Tokens, new NubIntType(false, 64));
public record Float32LiteralNode(List<Token> Tokens, float Value) : RValueExpressionNode(Tokens, new NubFloatType(32));
public record Float64LiteralNode(List<Token> Tokens, double Value) : RValueExpressionNode(Tokens, new NubFloatType(64));
public record BoolLiteralNode(List<Token> Tokens, NubType Type, bool Value) : RValueExpressionNode(Tokens, Type);
@@ -114,13 +126,13 @@ public record AddressOfNode(List<Token> Tokens, NubType Type, LValueExpressionNo
public record StructFieldAccessNode(List<Token> Tokens, NubType Type, ExpressionNode Target, string Field) : LValueExpressionNode(Tokens, Type);
public record StructInitializerNode(List<Token> Tokens, NubStructType StructType, Dictionary<string, ExpressionNode> Initializers) : RValueExpressionNode(Tokens, StructType);
public record StructInitializerNode(List<Token> Tokens, NubType Type, Dictionary<string, ExpressionNode> Initializers) : RValueExpressionNode(Tokens, Type);
public record DereferenceNode(List<Token> Tokens, NubType Type, ExpressionNode Target) : LValueExpressionNode(Tokens, Type);
public record ConvertIntNode(List<Token> Tokens, NubType Type, ExpressionNode Value, NubIntType ValueType, NubIntType TargetType) : RValueExpressionNode(Tokens, Type);
public record ConvertIntNode(List<Token> Tokens, ExpressionNode Value, int StartWidth, int TargetWidth, bool StartSignedness, bool TargetSignedness) : RValueExpressionNode(Tokens, new NubIntType(TargetSignedness, TargetWidth));
public record ConvertFloatNode(List<Token> Tokens, NubType Type, ExpressionNode Value, NubFloatType ValueType, NubFloatType TargetType) : RValueExpressionNode(Tokens, Type);
public record ConvertFloatNode(List<Token> Tokens, ExpressionNode Value, int StartWidth, int TargetWidth) : RValueExpressionNode(Tokens, new NubFloatType(TargetWidth));
public record ConvertStringToCStringNode(List<Token> Tokens, ExpressionNode Value) : RValueExpressionNode(Tokens, new NubCStringType());

View File

@@ -10,7 +10,6 @@ public sealed class TypeChecker
private readonly Dictionary<string, Module> _importedModules;
private readonly Stack<Scope> _scopes = [];
private readonly Stack<NubType> _funcReturnTypes = [];
private readonly Dictionary<(string Module, string Name), NubType> _typeCache = new();
private readonly HashSet<(string Module, string Name)> _resolvingTypes = [];
@@ -29,7 +28,6 @@ public sealed class TypeChecker
public CompilationUnit Check()
{
_scopes.Clear();
_funcReturnTypes.Clear();
_typeCache.Clear();
_resolvingTypes.Clear();
@@ -117,26 +115,26 @@ public sealed class TypeChecker
BlockNode? body = null;
if (node.Body != null)
{
_funcReturnTypes.Push(prototype.ReturnType);
body = CheckBlock(node.Body);
if (!AlwaysReturns(body))
using (BeginScope())
{
if (prototype.ReturnType is NubVoidType)
CurrentScope.SetReturnType(prototype.ReturnType);
body = CheckBlock(node.Body);
if (!AlwaysReturns(body))
{
body.Statements.Add(new ReturnNode(node.Tokens.Skip(node.Tokens.Count - 1).ToList(), null));
}
else
{
Diagnostics.Add(Diagnostic
.Error("Not all code paths return a value")
.At(node.Body)
.Build());
if (prototype.ReturnType is NubVoidType)
{
body.Statements.Add(new ReturnNode(node.Tokens.Skip(node.Tokens.Count - 1).ToList(), null));
}
else
{
Diagnostics.Add(Diagnostic
.Error("Not all code paths return a value")
.At(node.Body)
.Build());
}
}
}
_funcReturnTypes.Pop();
}
return new FuncNode(node.Tokens, prototype, body);
@@ -151,12 +149,21 @@ public sealed class TypeChecker
}
var value = CheckExpression(statement.Value, lValue.Type);
if (value.Type != lValue.Type)
{
throw new TypeCheckerException(Diagnostic
.Error($"Cannot assign {value.Type} to {lValue.Type}")
.At(statement.Value)
.Build());
}
return new AssignmentNode(statement.Tokens, lValue, value);
}
private IfNode CheckIf(IfSyntax statement)
{
var condition = CheckExpression(statement.Condition, new NubBoolType());
var condition = CheckExpression(statement.Condition);
var body = CheckBlock(statement.Body);
Variant<IfNode, BlockNode>? elseStatement = null;
if (statement.Else.HasValue)
@@ -173,7 +180,8 @@ public sealed class TypeChecker
if (statement.Value != null)
{
value = CheckExpression(statement.Value, _funcReturnTypes.Peek());
var expectedReturnType = CurrentScope.GetReturnType();
value = CheckExpression(statement.Value, expectedReturnType);
}
return new ReturnNode(statement.Tokens, value);
@@ -203,12 +211,26 @@ public sealed class TypeChecker
if (statement.Assignment != null)
{
assignmentNode = CheckExpression(statement.Assignment, type);
type ??= assignmentNode.Type;
if (type == null)
{
type = assignmentNode.Type;
}
else if (assignmentNode.Type != type)
{
throw new TypeCheckerException(Diagnostic
.Error($"Cannot assign {assignmentNode.Type} to variable of type {type}")
.At(statement.Assignment)
.Build());
}
}
if (type == null)
{
throw new TypeCheckerException(Diagnostic.Error($"Cannot infer type of variable {statement.Name}").At(statement).Build());
throw new TypeCheckerException(Diagnostic
.Error($"Cannot infer type of variable {statement.Name}")
.At(statement)
.Build());
}
CurrentScope.DeclareVariable(new Variable(statement.Name, type));
@@ -218,7 +240,7 @@ public sealed class TypeChecker
private WhileNode CheckWhile(WhileSyntax statement)
{
var condition = CheckExpression(statement.Condition, new NubBoolType());
var condition = CheckExpression(statement.Condition);
var body = CheckBlock(statement.Body);
return new WhileNode(statement.Tokens, condition, body);
}
@@ -236,64 +258,32 @@ public sealed class TypeChecker
private ExpressionNode CheckExpression(ExpressionSyntax node, NubType? expectedType = null)
{
var result = node switch
return node switch
{
AddressOfSyntax expression => CheckAddressOf(expression),
ArrayIndexAccessSyntax expression => CheckArrayIndexAccess(expression),
ArrayInitializerSyntax expression => CheckArrayInitializer(expression),
AddressOfSyntax expression => CheckAddressOf(expression, expectedType),
ArrayIndexAccessSyntax expression => CheckArrayIndexAccess(expression, expectedType),
ArrayInitializerSyntax expression => CheckArrayInitializer(expression, expectedType),
BinaryExpressionSyntax expression => CheckBinaryExpression(expression, expectedType),
UnaryExpressionSyntax expression => CheckUnaryExpression(expression, expectedType),
DereferenceSyntax expression => CheckDereference(expression),
FuncCallSyntax expression => CheckFuncCall(expression),
LocalIdentifierSyntax expression => CheckLocalIdentifier(expression),
ModuleIdentifierSyntax expression => CheckModuleIdentifier(expression),
BoolLiteralSyntax expression => CheckBoolLiteral(expression),
DereferenceSyntax expression => CheckDereference(expression, expectedType),
FuncCallSyntax expression => CheckFuncCall(expression, expectedType),
LocalIdentifierSyntax expression => CheckLocalIdentifier(expression, expectedType),
ModuleIdentifierSyntax expression => CheckModuleIdentifier(expression, expectedType),
BoolLiteralSyntax expression => CheckBoolLiteral(expression, expectedType),
StringLiteralSyntax expression => CheckStringLiteral(expression, expectedType),
IntLiteralSyntax expression => CheckIntLiteral(expression, expectedType),
FloatLiteralSyntax expression => CheckFloatLiteral(expression, expectedType),
MemberAccessSyntax expression => CheckMemberAccess(expression),
MemberAccessSyntax expression => CheckMemberAccess(expression, expectedType),
StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType),
InterpretBuiltinSyntax expression => CheckExpression(expression.Target) with { Type = ResolveType(expression.Type) },
InterpretBuiltinSyntax expression => CheckExpression(expression.Target, expectedType) with { Type = ResolveType(expression.Type) },
SizeBuiltinSyntax expression => new SizeBuiltinNode(node.Tokens, new NubIntType(false, 64), ResolveType(expression.Type)),
FloatToIntBuiltinSyntax expression => CheckFloatToInt(expression),
FloatToIntBuiltinSyntax expression => CheckFloatToInt(expression, expectedType),
_ => throw new ArgumentOutOfRangeException(nameof(node))
};
if (expectedType == null || result.Type == expectedType)
{
return result;
}
if (result.Type is NubStringType && expectedType is NubCStringType)
{
return new ConvertStringToCStringNode(node.Tokens, result);
}
if (result.Type is NubCStringType && expectedType is NubStringType)
{
return new ConvertCStringToStringNode(node.Tokens, result);
}
if (result.Type is NubIntType sourceIntType && expectedType is NubIntType targetIntType)
{
if (sourceIntType.Signed == targetIntType.Signed && sourceIntType.Width < targetIntType.Width)
{
return new ConvertIntNode(node.Tokens, targetIntType, result, sourceIntType, targetIntType);
}
}
if (result.Type is NubFloatType sourceFloatType && expectedType is NubFloatType targetFloatType)
{
if (sourceFloatType.Width < targetFloatType.Width)
{
return new ConvertFloatNode(node.Tokens, targetFloatType, result, sourceFloatType, targetFloatType);
}
}
throw new TypeCheckerException(Diagnostic.Error($"Cannot convert {result.Type} to {expectedType}").At(node).Build());
}
private FloatToIntBuiltinNode CheckFloatToInt(FloatToIntBuiltinSyntax expression)
// todo(nub31): Infer int type instead of explicit type syntax
private FloatToIntBuiltinNode CheckFloatToInt(FloatToIntBuiltinSyntax expression, NubType? _)
{
var value = CheckExpression(expression.Value);
if (value.Type is not NubFloatType sourceFloatType)
@@ -316,9 +306,9 @@ public sealed class TypeChecker
return new FloatToIntBuiltinNode(expression.Tokens, targetIntType, value, sourceFloatType, targetIntType);
}
private AddressOfNode CheckAddressOf(AddressOfSyntax expression)
private AddressOfNode CheckAddressOf(AddressOfSyntax expression, NubType? expectedType)
{
var target = CheckExpression(expression.Target);
var target = CheckExpression(expression.Target, (expectedType as NubPointerType)?.BaseType);
if (target is not LValueExpressionNode lvalue)
{
throw new TypeCheckerException(Diagnostic.Error("Cannot take address of an rvalue").At(expression).Build());
@@ -328,7 +318,7 @@ public sealed class TypeChecker
return new AddressOfNode(expression.Tokens, type, lvalue);
}
private ExpressionNode CheckArrayIndexAccess(ArrayIndexAccessSyntax expression)
private ExpressionNode CheckArrayIndexAccess(ArrayIndexAccessSyntax expression, NubType? _)
{
var index = CheckExpression(expression.Index);
if (index.Type is not NubIntType)
@@ -349,7 +339,8 @@ public sealed class TypeChecker
};
}
private ArrayInitializerNode CheckArrayInitializer(ArrayInitializerSyntax expression)
// todo(nub31): Allow type inference instead of specifying type in syntax. Something like just []
private ArrayInitializerNode CheckArrayInitializer(ArrayInitializerSyntax expression, NubType? _)
{
var elementType = ResolveType(expression.ElementType);
var type = new NubArrayType(elementType);
@@ -387,7 +378,6 @@ public sealed class TypeChecker
case BinaryOperatorSyntax.Equal:
case BinaryOperatorSyntax.NotEqual:
{
// note(nub31): Expected type should not be passed down since the resulting type is different than operands
var left = CheckExpression(expression.Left);
if (left.Type is not NubIntType and not NubFloatType and not NubBoolType)
{
@@ -406,7 +396,6 @@ public sealed class TypeChecker
case BinaryOperatorSyntax.LessThan:
case BinaryOperatorSyntax.LessThanOrEqual:
{
// note(nub31): Expected type should not be passed down since the resulting type is different than operands
var left = CheckExpression(expression.Left);
if (left.Type is not NubIntType and not NubFloatType)
{
@@ -423,7 +412,6 @@ public sealed class TypeChecker
case BinaryOperatorSyntax.LogicalAnd:
case BinaryOperatorSyntax.LogicalOr:
{
// note(nub31): Expected type should not be passed down since the resulting type is different than operands
var left = CheckExpression(expression.Left);
if (left.Type is not NubBoolType)
{
@@ -440,16 +428,16 @@ public sealed class TypeChecker
case BinaryOperatorSyntax.Plus:
{
var left = CheckExpression(expression.Left, expectedType);
if (left.Type is NubIntType or NubFloatType or NubStringType or NubCStringType)
if (left.Type is not NubIntType and not NubFloatType and not NubStringType and not NubCStringType)
{
var right = CheckExpression(expression.Right, left.Type);
return new BinaryExpressionNode(expression.Tokens, left.Type, left, op, right);
throw new TypeCheckerException(Diagnostic
.Error("The plus operator must be used with int, float, cstring or string types")
.At(expression.Left)
.Build());
}
throw new TypeCheckerException(Diagnostic
.Error("The plus operator must be used with int, float, cstring or string types")
.At(expression.Left)
.Build());
var right = CheckExpression(expression.Right, left.Type);
return new BinaryExpressionNode(expression.Tokens, left.Type, left, op, right);
}
case BinaryOperatorSyntax.Minus:
case BinaryOperatorSyntax.Multiply:
@@ -532,7 +520,7 @@ public sealed class TypeChecker
}
}
private DereferenceNode CheckDereference(DereferenceSyntax expression)
private DereferenceNode CheckDereference(DereferenceSyntax expression, NubType? _)
{
var target = CheckExpression(expression.Target);
if (target.Type is not NubPointerType pointerType)
@@ -543,7 +531,7 @@ public sealed class TypeChecker
return new DereferenceNode(expression.Tokens, pointerType.BaseType, target);
}
private FuncCallNode CheckFuncCall(FuncCallSyntax expression)
private FuncCallNode CheckFuncCall(FuncCallSyntax expression, NubType? _)
{
var accessor = CheckExpression(expression.Expression);
if (accessor.Type is not NubFuncType funcType)
@@ -563,13 +551,13 @@ public sealed class TypeChecker
for (var i = 0; i < expression.Parameters.Count; i++)
{
var parameter = expression.Parameters[i];
var expectedType = funcType.Parameters[i];
var expectedParameterType = funcType.Parameters[i];
var parameterExpression = CheckExpression(parameter, expectedType);
if (parameterExpression.Type != expectedType)
var parameterExpression = CheckExpression(parameter, expectedParameterType);
if (parameterExpression.Type != expectedParameterType)
{
throw new TypeCheckerException(Diagnostic
.Error($"Parameter {i + 1} does not match the type {expectedType} for function {funcType}")
.Error($"Parameter {i + 1} does not match the type {expectedParameterType} for function {funcType}")
.At(parameter)
.Build());
}
@@ -580,7 +568,7 @@ public sealed class TypeChecker
return new FuncCallNode(expression.Tokens, funcType.ReturnType, accessor, parameters);
}
private ExpressionNode CheckLocalIdentifier(LocalIdentifierSyntax expression)
private ExpressionNode CheckLocalIdentifier(LocalIdentifierSyntax expression, NubType? _)
{
var scopeIdent = CurrentScope.LookupVariable(expression.Name);
if (scopeIdent != null)
@@ -601,7 +589,7 @@ public sealed class TypeChecker
throw new TypeCheckerException(Diagnostic.Error($"Symbol {expression.Name} not found").At(expression).Build());
}
private ExpressionNode CheckModuleIdentifier(ModuleIdentifierSyntax expression)
private ExpressionNode CheckModuleIdentifier(ModuleIdentifierSyntax expression, NubType? _)
{
if (!_importedModules.TryGetValue(expression.Module, out var module))
{
@@ -630,51 +618,62 @@ public sealed class TypeChecker
private ExpressionNode CheckStringLiteral(StringLiteralSyntax expression, NubType? expectedType)
{
return expectedType switch
if (expectedType is NubCStringType)
{
NubCStringType => new CStringLiteralNode(expression.Tokens, expectedType, expression.Value),
NubStringType => new StringLiteralNode(expression.Tokens, expectedType, expression.Value),
_ => new StringLiteralNode(expression.Tokens, new NubStringType(), expression.Value)
};
return new CStringLiteralNode(expression.Tokens, expression.Value);
}
return new StringLiteralNode(expression.Tokens, expression.Value);
}
private ExpressionNode CheckIntLiteral(IntLiteralSyntax expression, NubType? expectedType)
{
if (expectedType is NubIntType intType)
{
return intType.Signed
? new IntLiteralNode(expression.Tokens, intType, Convert.ToInt64(expression.Value, expression.Base))
: new UIntLiteralNode(expression.Tokens, intType, Convert.ToUInt64(expression.Value, expression.Base));
return intType.Width switch
{
8 => intType.Signed ? new I8LiteralNode(expression.Tokens, Convert.ToSByte(expression.Value, expression.Base)) : new U8LiteralNode(expression.Tokens, Convert.ToByte(expression.Value, expression.Base)),
16 => intType.Signed ? new I16LiteralNode(expression.Tokens, Convert.ToInt16(expression.Value, expression.Base)) : new U16LiteralNode(expression.Tokens, Convert.ToUInt16(expression.Value, expression.Base)),
32 => intType.Signed ? new I32LiteralNode(expression.Tokens, Convert.ToInt32(expression.Value, expression.Base)) : new U32LiteralNode(expression.Tokens, Convert.ToUInt32(expression.Value, expression.Base)),
64 => intType.Signed ? new I64LiteralNode(expression.Tokens, Convert.ToInt64(expression.Value, expression.Base)) : new U64LiteralNode(expression.Tokens, Convert.ToUInt64(expression.Value, expression.Base)),
_ => throw new ArgumentOutOfRangeException()
};
}
if (expectedType is NubFloatType floatType)
{
return floatType.Width switch
{
32 => new Float32LiteralNode(expression.Tokens, floatType, Convert.ToSingle(Convert.ToInt64(expression.Value, expression.Base))),
64 => new Float64LiteralNode(expression.Tokens, floatType, Convert.ToDouble(Convert.ToInt64(expression.Value, expression.Base))),
32 => new Float32LiteralNode(expression.Tokens, Convert.ToSingle(expression.Value)),
64 => new Float64LiteralNode(expression.Tokens, Convert.ToDouble(expression.Value)),
_ => throw new ArgumentOutOfRangeException()
};
}
var type = new NubIntType(true, 64);
return new IntLiteralNode(expression.Tokens, type, Convert.ToInt64(expression.Value, expression.Base));
return new I64LiteralNode(expression.Tokens, Convert.ToInt64(expression.Value, expression.Base));
}
private ExpressionNode CheckFloatLiteral(FloatLiteralSyntax expression, NubType? expectedType)
{
var type = expectedType as NubFloatType ?? new NubFloatType(64);
return type.Width == 32
? new Float32LiteralNode(expression.Tokens, type, Convert.ToSingle(expression.Value))
: new Float64LiteralNode(expression.Tokens, type, Convert.ToDouble(expression.Value));
if (expectedType is NubFloatType floatType)
{
return floatType.Width switch
{
32 => new Float32LiteralNode(expression.Tokens, Convert.ToSingle(expression.Value)),
64 => new Float64LiteralNode(expression.Tokens, Convert.ToDouble(expression.Value)),
_ => throw new ArgumentOutOfRangeException()
};
}
return new Float64LiteralNode(expression.Tokens, Convert.ToDouble(expression.Value));
}
private BoolLiteralNode CheckBoolLiteral(BoolLiteralSyntax expression)
private BoolLiteralNode CheckBoolLiteral(BoolLiteralSyntax expression, NubType? _)
{
return new BoolLiteralNode(expression.Tokens, new NubBoolType(), expression.Value);
}
private ExpressionNode CheckMemberAccess(MemberAccessSyntax expression)
private ExpressionNode CheckMemberAccess(MemberAccessSyntax expression, NubType? _)
{
var target = CheckExpression(expression.Target);
switch (target.Type)
@@ -897,6 +896,7 @@ public record Variable(string Name, NubType Type);
public class Scope(string module, Scope? parent = null)
{
private NubType? _returnType;
private readonly List<Variable> _variables = [];
public string Module { get; } = module;
@@ -905,6 +905,16 @@ public class Scope(string module, Scope? parent = null)
_variables.Add(variable);
}
public void SetReturnType(NubType returnType)
{
_returnType = returnType;
}
public NubType? GetReturnType()
{
return _returnType ?? parent?.GetReturnType();
}
public Variable? LookupVariable(string name)
{
var variable = _variables.FirstOrDefault(x => x.Name == name);

View File

@@ -14,12 +14,12 @@ public static class CType
NubFloatType floatType => CreateFloatType(floatType, variableName),
NubCStringType => "char*" + (variableName != null ? $" {variableName}" : ""),
NubPointerType ptr => CreatePointerType(ptr, variableName),
NubSliceType nubSliceType => "slice" + (variableName != null ? $" {variableName}" : ""),
NubStringType nubStringType => "string" + (variableName != null ? $" {variableName}" : ""),
NubSliceType => "slice" + (variableName != null ? $" {variableName}" : ""),
NubStringType => "string" + (variableName != null ? $" {variableName}" : ""),
NubConstArrayType arr => CreateConstArrayType(arr, variableName),
NubArrayType arr => CreateArrayType(arr, variableName),
NubFuncType fn => CreateFuncType(fn, variableName),
NubStructType st => $"{st.Name}" + (variableName != null ? $" {variableName}" : ""),
NubStructType st => $"{st.Module}_{st.Name}" + (variableName != null ? $" {variableName}" : ""),
_ => throw new NotSupportedException($"C type generation not supported for: {type}")
};
}

View File

@@ -35,12 +35,10 @@ public class Generator
public string Emit()
{
_writer.WriteLine("#include <stdint.h>");
_writer.WriteLine("#include <stdarg.h>");
_writer.WriteLine("#include <stddef.h>");
_writer.WriteLine();
_writer.WriteLine("""
#include <stdint.h>
#include <stddef.h>
typedef struct
{
size_t length;
@@ -274,14 +272,20 @@ public class Generator
FloatToIntBuiltinNode floatToIntBuiltinNode => EmitFloatToIntBuiltin(floatToIntBuiltinNode),
FuncCallNode funcCallNode => EmitFuncCall(funcCallNode),
FuncIdentifierNode funcIdentifierNode => FuncName(funcIdentifierNode.Module, funcIdentifierNode.Name, funcIdentifierNode.ExternSymbol),
IntLiteralNode intLiteralNode => EmitIntLiteral(intLiteralNode),
AddressOfNode addressOfNode => EmitAddressOf(addressOfNode),
SizeBuiltinNode sizeBuiltinNode => $"sizeof({CType.Create(sizeBuiltinNode.TargetType)})",
SliceIndexAccessNode sliceIndexAccessNode => EmitSliceArrayIndexAccess(sliceIndexAccessNode),
StringLiteralNode stringLiteralNode => EmitStringLiteral(stringLiteralNode),
StructFieldAccessNode structFieldAccessNode => EmitStructFieldAccess(structFieldAccessNode),
StructInitializerNode structInitializerNode => EmitStructInitializer(structInitializerNode),
UIntLiteralNode uIntLiteralNode => EmitUIntLiteral(uIntLiteralNode),
I8LiteralNode i8LiteralNode => EmitI8Literal(i8LiteralNode),
I16LiteralNode i16LiteralNode => EmitI16Literal(i16LiteralNode),
I32LiteralNode i32LiteralNode => EmitI32Literal(i32LiteralNode),
I64LiteralNode i64LiteralNode => EmitI64Literal(i64LiteralNode),
U8LiteralNode u8LiteralNode => EmitU8Literal(u8LiteralNode),
U16LiteralNode u16LiteralNode => EmitU16Literal(u16LiteralNode),
U32LiteralNode u32LiteralNode => EmitU32Literal(u32LiteralNode),
U64LiteralNode u64LiteralNode => EmitU64Literal(u64LiteralNode),
UnaryExpressionNode unaryExpressionNode => EmitUnaryExpression(unaryExpressionNode),
VariableIdentifierNode variableIdentifierNode => variableIdentifierNode.Name,
_ => throw new ArgumentOutOfRangeException(nameof(expressionNode))
@@ -345,7 +349,7 @@ public class Generator
private string EmitConvertFloat(ConvertFloatNode convertFloatNode)
{
var value = EmitExpression(convertFloatNode.Value);
var targetCast = convertFloatNode.TargetType.Width switch
var targetCast = convertFloatNode.TargetWidth switch
{
32 => "f32",
64 => "f64",
@@ -358,12 +362,12 @@ public class Generator
private string EmitConvertInt(ConvertIntNode convertIntNode)
{
var value = EmitExpression(convertIntNode.Value);
var targetType = convertIntNode.TargetType.Width switch
var targetType = convertIntNode.TargetWidth switch
{
8 => convertIntNode.TargetType.Signed ? "int8_t" : "uint8_t",
16 => convertIntNode.TargetType.Signed ? "int16_t" : "uint16_t",
32 => convertIntNode.TargetType.Signed ? "int32_t" : "uint32_t",
64 => convertIntNode.TargetType.Signed ? "int64_t" : "uint64_t",
8 => convertIntNode.TargetSignedness ? "int8_t" : "uint8_t",
16 => convertIntNode.TargetSignedness ? "int16_t" : "uint16_t",
32 => convertIntNode.TargetSignedness ? "int32_t" : "uint32_t",
64 => convertIntNode.TargetSignedness ? "int64_t" : "uint64_t",
_ => throw new ArgumentOutOfRangeException()
};
return $"({targetType}){value}";
@@ -420,21 +424,10 @@ public class Generator
private string EmitFuncCall(FuncCallNode funcCallNode)
{
var name = EmitExpression(funcCallNode.Expression);
var parameterNames = funcCallNode.Parameters.Select(x => EmitExpression(x)).ToList();
var parameterNames = funcCallNode.Parameters.Select(EmitExpression).ToList();
return $"{name}({string.Join(", ", parameterNames)})";
}
private string EmitIntLiteral(IntLiteralNode intLiteralNode)
{
var type = (NubIntType)intLiteralNode.Type;
return type.Width switch
{
8 or 16 or 32 => intLiteralNode.Value.ToString(),
64 => intLiteralNode.Value + "LL",
_ => throw new ArgumentOutOfRangeException()
};
}
private string EmitAddressOf(AddressOfNode addressOfNode)
{
var value = EmitExpression(addressOfNode.LValue);
@@ -477,15 +470,44 @@ public class Generator
return $"({CType.Create(structInitializerNode.Type)}){{{initString}}}";
}
private string EmitUIntLiteral(UIntLiteralNode uIntLiteralNode)
private string EmitI8Literal(I8LiteralNode i8LiteralNode)
{
var type = (NubIntType)uIntLiteralNode.Type;
return type.Width switch
{
8 or 16 or 32 => uIntLiteralNode.Value + "U",
64 => uIntLiteralNode.Value + "ULL",
_ => throw new ArgumentOutOfRangeException()
};
return i8LiteralNode.Value.ToString();
}
private string EmitI16Literal(I16LiteralNode i16LiteralNode)
{
return i16LiteralNode.Value.ToString();
}
private string EmitI32Literal(I32LiteralNode i32LiteralNode)
{
return i32LiteralNode.Value.ToString();
}
private string EmitI64Literal(I64LiteralNode i64LiteralNode)
{
return i64LiteralNode.Value + "LL";
}
private string EmitU8Literal(U8LiteralNode u8LiteralNode)
{
return u8LiteralNode.Value.ToString();
}
private string EmitU16Literal(U16LiteralNode u16LiteralNode)
{
return u16LiteralNode.Value.ToString();
}
private string EmitU32Literal(U32LiteralNode u32LiteralNode)
{
return u32LiteralNode.Value.ToString();
}
private string EmitU64Literal(U64LiteralNode u64LiteralNode)
{
return u64LiteralNode.Value + "ULL";
}
private string EmitUnaryExpression(UnaryExpressionNode unaryExpressionNode)

View File

@@ -5,5 +5,6 @@ extern "puts" func puts(text: cstring)
extern "main" func main(argc: i64, argv: [?]cstring): i64
{
let x = [23]i32
puts("test")
return x[0]
}

View File

@@ -1,5 +1,5 @@
.build/out: main.nub
nubc main.nub
clang $$(nubc main.nub) -o .build/out
clean:
@rm -r .build 2>/dev/null || true

View File

@@ -1,5 +1,5 @@
.build/out: main.nub
nubc main.nub
clang $$(nubc main.nub) -o .build/out
clean:
@rm -r .build 2>/dev/null || true