Big type inference improvements

This commit is contained in:
nub31
2025-10-22 12:55:31 +02:00
parent e2da6cccff
commit 93cef598e8
8 changed files with 205 additions and 163 deletions

View File

@@ -7,10 +7,7 @@ using NubLang.Syntax;
var diagnostics = new List<Diagnostic>(); var diagnostics = new List<Diagnostic>();
var syntaxTrees = new List<SyntaxTree>(); var syntaxTrees = new List<SyntaxTree>();
var nubFiles = args.Where(x => Path.GetExtension(x) == ".nub").ToArray(); foreach (var file in args)
var objectFileArgs = args.Where(x => Path.GetExtension(x) is ".o" or ".a").ToArray();
foreach (var file in nubFiles)
{ {
var tokenizer = new Tokenizer(file, File.ReadAllText(file)); var tokenizer = new Tokenizer(file, File.ReadAllText(file));
tokenizer.Tokenize(); tokenizer.Tokenize();
@@ -26,7 +23,7 @@ foreach (var file in nubFiles)
var modules = Module.Collect(syntaxTrees); var modules = Module.Collect(syntaxTrees);
var compilationUnits = new List<CompilationUnit>(); var compilationUnits = new List<CompilationUnit>();
for (var i = 0; i < nubFiles.Length; i++) for (var i = 0; i < args.Length; i++)
{ {
var typeChecker = new TypeChecker(syntaxTrees[i], modules); var typeChecker = new TypeChecker(syntaxTrees[i], modules);
var compilationUnit = typeChecker.Check(); var compilationUnit = typeChecker.Check();
@@ -49,9 +46,9 @@ var cPaths = new List<string>();
Directory.CreateDirectory(".build"); Directory.CreateDirectory(".build");
for (var i = 0; i < nubFiles.Length; i++) for (var i = 0; i < args.Length; i++)
{ {
var file = nubFiles[i]; var file = args[i];
var compilationUnit = compilationUnits[i]; var compilationUnit = compilationUnits[i];
var generator = new Generator(compilationUnit); var generator = new Generator(compilationUnit);

View File

@@ -80,17 +80,29 @@ public abstract record LValueExpressionNode(List<Token> Tokens, NubType Type) :
public abstract record RValueExpressionNode(List<Token> Tokens, NubType Type) : ExpressionNode(Tokens, Type); public abstract record RValueExpressionNode(List<Token> Tokens, NubType Type) : ExpressionNode(Tokens, Type);
public record StringLiteralNode(List<Token> Tokens, NubType Type, string Value) : RValueExpressionNode(Tokens, Type); public record StringLiteralNode(List<Token> Tokens, string Value) : RValueExpressionNode(Tokens, new NubStringType());
public record CStringLiteralNode(List<Token> Tokens, NubType Type, string Value) : RValueExpressionNode(Tokens, Type); public record CStringLiteralNode(List<Token> Tokens, string Value) : RValueExpressionNode(Tokens, new NubCStringType());
public record IntLiteralNode(List<Token> Tokens, NubType Type, long Value) : RValueExpressionNode(Tokens, Type); public record I8LiteralNode(List<Token> Tokens, sbyte Value) : RValueExpressionNode(Tokens, new NubIntType(true, 8));
public record UIntLiteralNode(List<Token> Tokens, NubType Type, ulong Value) : RValueExpressionNode(Tokens, Type); public record I16LiteralNode(List<Token> Tokens, short Value) : RValueExpressionNode(Tokens, new NubIntType(true, 16));
public record Float32LiteralNode(List<Token> Tokens, NubType Type, float Value) : RValueExpressionNode(Tokens, Type); public record I32LiteralNode(List<Token> Tokens, int Value) : RValueExpressionNode(Tokens, new NubIntType(true, 32));
public record Float64LiteralNode(List<Token> Tokens, NubType Type, double Value) : RValueExpressionNode(Tokens, Type); public record I64LiteralNode(List<Token> Tokens, long Value) : RValueExpressionNode(Tokens, new NubIntType(true, 64));
public record U8LiteralNode(List<Token> Tokens, byte Value) : RValueExpressionNode(Tokens, new NubIntType(false, 8));
public record U16LiteralNode(List<Token> Tokens, ushort Value) : RValueExpressionNode(Tokens, new NubIntType(false, 16));
public record U32LiteralNode(List<Token> Tokens, uint Value) : RValueExpressionNode(Tokens, new NubIntType(false, 32));
public record U64LiteralNode(List<Token> Tokens, ulong Value) : RValueExpressionNode(Tokens, new NubIntType(false, 64));
public record Float32LiteralNode(List<Token> Tokens, float Value) : RValueExpressionNode(Tokens, new NubFloatType(32));
public record Float64LiteralNode(List<Token> Tokens, double Value) : RValueExpressionNode(Tokens, new NubFloatType(64));
public record BoolLiteralNode(List<Token> Tokens, NubType Type, bool Value) : RValueExpressionNode(Tokens, Type); public record BoolLiteralNode(List<Token> Tokens, NubType Type, bool Value) : RValueExpressionNode(Tokens, Type);
@@ -114,13 +126,13 @@ public record AddressOfNode(List<Token> Tokens, NubType Type, LValueExpressionNo
public record StructFieldAccessNode(List<Token> Tokens, NubType Type, ExpressionNode Target, string Field) : LValueExpressionNode(Tokens, Type); public record StructFieldAccessNode(List<Token> Tokens, NubType Type, ExpressionNode Target, string Field) : LValueExpressionNode(Tokens, Type);
public record StructInitializerNode(List<Token> Tokens, NubStructType StructType, Dictionary<string, ExpressionNode> Initializers) : RValueExpressionNode(Tokens, StructType); public record StructInitializerNode(List<Token> Tokens, NubType Type, Dictionary<string, ExpressionNode> Initializers) : RValueExpressionNode(Tokens, Type);
public record DereferenceNode(List<Token> Tokens, NubType Type, ExpressionNode Target) : LValueExpressionNode(Tokens, Type); public record DereferenceNode(List<Token> Tokens, NubType Type, ExpressionNode Target) : LValueExpressionNode(Tokens, Type);
public record ConvertIntNode(List<Token> Tokens, NubType Type, ExpressionNode Value, NubIntType ValueType, NubIntType TargetType) : RValueExpressionNode(Tokens, Type); public record ConvertIntNode(List<Token> Tokens, ExpressionNode Value, int StartWidth, int TargetWidth, bool StartSignedness, bool TargetSignedness) : RValueExpressionNode(Tokens, new NubIntType(TargetSignedness, TargetWidth));
public record ConvertFloatNode(List<Token> Tokens, NubType Type, ExpressionNode Value, NubFloatType ValueType, NubFloatType TargetType) : RValueExpressionNode(Tokens, Type); public record ConvertFloatNode(List<Token> Tokens, ExpressionNode Value, int StartWidth, int TargetWidth) : RValueExpressionNode(Tokens, new NubFloatType(TargetWidth));
public record ConvertStringToCStringNode(List<Token> Tokens, ExpressionNode Value) : RValueExpressionNode(Tokens, new NubCStringType()); public record ConvertStringToCStringNode(List<Token> Tokens, ExpressionNode Value) : RValueExpressionNode(Tokens, new NubCStringType());

View File

@@ -10,7 +10,6 @@ public sealed class TypeChecker
private readonly Dictionary<string, Module> _importedModules; private readonly Dictionary<string, Module> _importedModules;
private readonly Stack<Scope> _scopes = []; private readonly Stack<Scope> _scopes = [];
private readonly Stack<NubType> _funcReturnTypes = [];
private readonly Dictionary<(string Module, string Name), NubType> _typeCache = new(); private readonly Dictionary<(string Module, string Name), NubType> _typeCache = new();
private readonly HashSet<(string Module, string Name)> _resolvingTypes = []; private readonly HashSet<(string Module, string Name)> _resolvingTypes = [];
@@ -29,7 +28,6 @@ public sealed class TypeChecker
public CompilationUnit Check() public CompilationUnit Check()
{ {
_scopes.Clear(); _scopes.Clear();
_funcReturnTypes.Clear();
_typeCache.Clear(); _typeCache.Clear();
_resolvingTypes.Clear(); _resolvingTypes.Clear();
@@ -117,8 +115,9 @@ public sealed class TypeChecker
BlockNode? body = null; BlockNode? body = null;
if (node.Body != null) if (node.Body != null)
{ {
_funcReturnTypes.Push(prototype.ReturnType); using (BeginScope())
{
CurrentScope.SetReturnType(prototype.ReturnType);
body = CheckBlock(node.Body); body = CheckBlock(node.Body);
if (!AlwaysReturns(body)) if (!AlwaysReturns(body))
@@ -135,8 +134,7 @@ public sealed class TypeChecker
.Build()); .Build());
} }
} }
}
_funcReturnTypes.Pop();
} }
return new FuncNode(node.Tokens, prototype, body); return new FuncNode(node.Tokens, prototype, body);
@@ -151,12 +149,21 @@ public sealed class TypeChecker
} }
var value = CheckExpression(statement.Value, lValue.Type); var value = CheckExpression(statement.Value, lValue.Type);
if (value.Type != lValue.Type)
{
throw new TypeCheckerException(Diagnostic
.Error($"Cannot assign {value.Type} to {lValue.Type}")
.At(statement.Value)
.Build());
}
return new AssignmentNode(statement.Tokens, lValue, value); return new AssignmentNode(statement.Tokens, lValue, value);
} }
private IfNode CheckIf(IfSyntax statement) private IfNode CheckIf(IfSyntax statement)
{ {
var condition = CheckExpression(statement.Condition, new NubBoolType()); var condition = CheckExpression(statement.Condition);
var body = CheckBlock(statement.Body); var body = CheckBlock(statement.Body);
Variant<IfNode, BlockNode>? elseStatement = null; Variant<IfNode, BlockNode>? elseStatement = null;
if (statement.Else.HasValue) if (statement.Else.HasValue)
@@ -173,7 +180,8 @@ public sealed class TypeChecker
if (statement.Value != null) if (statement.Value != null)
{ {
value = CheckExpression(statement.Value, _funcReturnTypes.Peek()); var expectedReturnType = CurrentScope.GetReturnType();
value = CheckExpression(statement.Value, expectedReturnType);
} }
return new ReturnNode(statement.Tokens, value); return new ReturnNode(statement.Tokens, value);
@@ -203,12 +211,26 @@ public sealed class TypeChecker
if (statement.Assignment != null) if (statement.Assignment != null)
{ {
assignmentNode = CheckExpression(statement.Assignment, type); assignmentNode = CheckExpression(statement.Assignment, type);
type ??= assignmentNode.Type;
if (type == null)
{
type = assignmentNode.Type;
}
else if (assignmentNode.Type != type)
{
throw new TypeCheckerException(Diagnostic
.Error($"Cannot assign {assignmentNode.Type} to variable of type {type}")
.At(statement.Assignment)
.Build());
}
} }
if (type == null) if (type == null)
{ {
throw new TypeCheckerException(Diagnostic.Error($"Cannot infer type of variable {statement.Name}").At(statement).Build()); throw new TypeCheckerException(Diagnostic
.Error($"Cannot infer type of variable {statement.Name}")
.At(statement)
.Build());
} }
CurrentScope.DeclareVariable(new Variable(statement.Name, type)); CurrentScope.DeclareVariable(new Variable(statement.Name, type));
@@ -218,7 +240,7 @@ public sealed class TypeChecker
private WhileNode CheckWhile(WhileSyntax statement) private WhileNode CheckWhile(WhileSyntax statement)
{ {
var condition = CheckExpression(statement.Condition, new NubBoolType()); var condition = CheckExpression(statement.Condition);
var body = CheckBlock(statement.Body); var body = CheckBlock(statement.Body);
return new WhileNode(statement.Tokens, condition, body); return new WhileNode(statement.Tokens, condition, body);
} }
@@ -236,64 +258,32 @@ public sealed class TypeChecker
private ExpressionNode CheckExpression(ExpressionSyntax node, NubType? expectedType = null) private ExpressionNode CheckExpression(ExpressionSyntax node, NubType? expectedType = null)
{ {
var result = node switch return node switch
{ {
AddressOfSyntax expression => CheckAddressOf(expression), AddressOfSyntax expression => CheckAddressOf(expression, expectedType),
ArrayIndexAccessSyntax expression => CheckArrayIndexAccess(expression), ArrayIndexAccessSyntax expression => CheckArrayIndexAccess(expression, expectedType),
ArrayInitializerSyntax expression => CheckArrayInitializer(expression), ArrayInitializerSyntax expression => CheckArrayInitializer(expression, expectedType),
BinaryExpressionSyntax expression => CheckBinaryExpression(expression, expectedType), BinaryExpressionSyntax expression => CheckBinaryExpression(expression, expectedType),
UnaryExpressionSyntax expression => CheckUnaryExpression(expression, expectedType), UnaryExpressionSyntax expression => CheckUnaryExpression(expression, expectedType),
DereferenceSyntax expression => CheckDereference(expression), DereferenceSyntax expression => CheckDereference(expression, expectedType),
FuncCallSyntax expression => CheckFuncCall(expression), FuncCallSyntax expression => CheckFuncCall(expression, expectedType),
LocalIdentifierSyntax expression => CheckLocalIdentifier(expression), LocalIdentifierSyntax expression => CheckLocalIdentifier(expression, expectedType),
ModuleIdentifierSyntax expression => CheckModuleIdentifier(expression), ModuleIdentifierSyntax expression => CheckModuleIdentifier(expression, expectedType),
BoolLiteralSyntax expression => CheckBoolLiteral(expression), BoolLiteralSyntax expression => CheckBoolLiteral(expression, expectedType),
StringLiteralSyntax expression => CheckStringLiteral(expression, expectedType), StringLiteralSyntax expression => CheckStringLiteral(expression, expectedType),
IntLiteralSyntax expression => CheckIntLiteral(expression, expectedType), IntLiteralSyntax expression => CheckIntLiteral(expression, expectedType),
FloatLiteralSyntax expression => CheckFloatLiteral(expression, expectedType), FloatLiteralSyntax expression => CheckFloatLiteral(expression, expectedType),
MemberAccessSyntax expression => CheckMemberAccess(expression), MemberAccessSyntax expression => CheckMemberAccess(expression, expectedType),
StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType), StructInitializerSyntax expression => CheckStructInitializer(expression, expectedType),
InterpretBuiltinSyntax expression => CheckExpression(expression.Target) with { Type = ResolveType(expression.Type) }, InterpretBuiltinSyntax expression => CheckExpression(expression.Target, expectedType) with { Type = ResolveType(expression.Type) },
SizeBuiltinSyntax expression => new SizeBuiltinNode(node.Tokens, new NubIntType(false, 64), ResolveType(expression.Type)), SizeBuiltinSyntax expression => new SizeBuiltinNode(node.Tokens, new NubIntType(false, 64), ResolveType(expression.Type)),
FloatToIntBuiltinSyntax expression => CheckFloatToInt(expression), FloatToIntBuiltinSyntax expression => CheckFloatToInt(expression, expectedType),
_ => throw new ArgumentOutOfRangeException(nameof(node)) _ => throw new ArgumentOutOfRangeException(nameof(node))
}; };
if (expectedType == null || result.Type == expectedType)
{
return result;
} }
if (result.Type is NubStringType && expectedType is NubCStringType) // todo(nub31): Infer int type instead of explicit type syntax
{ private FloatToIntBuiltinNode CheckFloatToInt(FloatToIntBuiltinSyntax expression, NubType? _)
return new ConvertStringToCStringNode(node.Tokens, result);
}
if (result.Type is NubCStringType && expectedType is NubStringType)
{
return new ConvertCStringToStringNode(node.Tokens, result);
}
if (result.Type is NubIntType sourceIntType && expectedType is NubIntType targetIntType)
{
if (sourceIntType.Signed == targetIntType.Signed && sourceIntType.Width < targetIntType.Width)
{
return new ConvertIntNode(node.Tokens, targetIntType, result, sourceIntType, targetIntType);
}
}
if (result.Type is NubFloatType sourceFloatType && expectedType is NubFloatType targetFloatType)
{
if (sourceFloatType.Width < targetFloatType.Width)
{
return new ConvertFloatNode(node.Tokens, targetFloatType, result, sourceFloatType, targetFloatType);
}
}
throw new TypeCheckerException(Diagnostic.Error($"Cannot convert {result.Type} to {expectedType}").At(node).Build());
}
private FloatToIntBuiltinNode CheckFloatToInt(FloatToIntBuiltinSyntax expression)
{ {
var value = CheckExpression(expression.Value); var value = CheckExpression(expression.Value);
if (value.Type is not NubFloatType sourceFloatType) if (value.Type is not NubFloatType sourceFloatType)
@@ -316,9 +306,9 @@ public sealed class TypeChecker
return new FloatToIntBuiltinNode(expression.Tokens, targetIntType, value, sourceFloatType, targetIntType); return new FloatToIntBuiltinNode(expression.Tokens, targetIntType, value, sourceFloatType, targetIntType);
} }
private AddressOfNode CheckAddressOf(AddressOfSyntax expression) private AddressOfNode CheckAddressOf(AddressOfSyntax expression, NubType? expectedType)
{ {
var target = CheckExpression(expression.Target); var target = CheckExpression(expression.Target, (expectedType as NubPointerType)?.BaseType);
if (target is not LValueExpressionNode lvalue) if (target is not LValueExpressionNode lvalue)
{ {
throw new TypeCheckerException(Diagnostic.Error("Cannot take address of an rvalue").At(expression).Build()); throw new TypeCheckerException(Diagnostic.Error("Cannot take address of an rvalue").At(expression).Build());
@@ -328,7 +318,7 @@ public sealed class TypeChecker
return new AddressOfNode(expression.Tokens, type, lvalue); return new AddressOfNode(expression.Tokens, type, lvalue);
} }
private ExpressionNode CheckArrayIndexAccess(ArrayIndexAccessSyntax expression) private ExpressionNode CheckArrayIndexAccess(ArrayIndexAccessSyntax expression, NubType? _)
{ {
var index = CheckExpression(expression.Index); var index = CheckExpression(expression.Index);
if (index.Type is not NubIntType) if (index.Type is not NubIntType)
@@ -349,7 +339,8 @@ public sealed class TypeChecker
}; };
} }
private ArrayInitializerNode CheckArrayInitializer(ArrayInitializerSyntax expression) // todo(nub31): Allow type inference instead of specifying type in syntax. Something like just []
private ArrayInitializerNode CheckArrayInitializer(ArrayInitializerSyntax expression, NubType? _)
{ {
var elementType = ResolveType(expression.ElementType); var elementType = ResolveType(expression.ElementType);
var type = new NubArrayType(elementType); var type = new NubArrayType(elementType);
@@ -387,7 +378,6 @@ public sealed class TypeChecker
case BinaryOperatorSyntax.Equal: case BinaryOperatorSyntax.Equal:
case BinaryOperatorSyntax.NotEqual: case BinaryOperatorSyntax.NotEqual:
{ {
// note(nub31): Expected type should not be passed down since the resulting type is different than operands
var left = CheckExpression(expression.Left); var left = CheckExpression(expression.Left);
if (left.Type is not NubIntType and not NubFloatType and not NubBoolType) if (left.Type is not NubIntType and not NubFloatType and not NubBoolType)
{ {
@@ -406,7 +396,6 @@ public sealed class TypeChecker
case BinaryOperatorSyntax.LessThan: case BinaryOperatorSyntax.LessThan:
case BinaryOperatorSyntax.LessThanOrEqual: case BinaryOperatorSyntax.LessThanOrEqual:
{ {
// note(nub31): Expected type should not be passed down since the resulting type is different than operands
var left = CheckExpression(expression.Left); var left = CheckExpression(expression.Left);
if (left.Type is not NubIntType and not NubFloatType) if (left.Type is not NubIntType and not NubFloatType)
{ {
@@ -423,7 +412,6 @@ public sealed class TypeChecker
case BinaryOperatorSyntax.LogicalAnd: case BinaryOperatorSyntax.LogicalAnd:
case BinaryOperatorSyntax.LogicalOr: case BinaryOperatorSyntax.LogicalOr:
{ {
// note(nub31): Expected type should not be passed down since the resulting type is different than operands
var left = CheckExpression(expression.Left); var left = CheckExpression(expression.Left);
if (left.Type is not NubBoolType) if (left.Type is not NubBoolType)
{ {
@@ -440,17 +428,17 @@ public sealed class TypeChecker
case BinaryOperatorSyntax.Plus: case BinaryOperatorSyntax.Plus:
{ {
var left = CheckExpression(expression.Left, expectedType); var left = CheckExpression(expression.Left, expectedType);
if (left.Type is NubIntType or NubFloatType or NubStringType or NubCStringType) if (left.Type is not NubIntType and not NubFloatType and not NubStringType and not NubCStringType)
{ {
var right = CheckExpression(expression.Right, left.Type);
return new BinaryExpressionNode(expression.Tokens, left.Type, left, op, right);
}
throw new TypeCheckerException(Diagnostic throw new TypeCheckerException(Diagnostic
.Error("The plus operator must be used with int, float, cstring or string types") .Error("The plus operator must be used with int, float, cstring or string types")
.At(expression.Left) .At(expression.Left)
.Build()); .Build());
} }
var right = CheckExpression(expression.Right, left.Type);
return new BinaryExpressionNode(expression.Tokens, left.Type, left, op, right);
}
case BinaryOperatorSyntax.Minus: case BinaryOperatorSyntax.Minus:
case BinaryOperatorSyntax.Multiply: case BinaryOperatorSyntax.Multiply:
case BinaryOperatorSyntax.Divide: case BinaryOperatorSyntax.Divide:
@@ -532,7 +520,7 @@ public sealed class TypeChecker
} }
} }
private DereferenceNode CheckDereference(DereferenceSyntax expression) private DereferenceNode CheckDereference(DereferenceSyntax expression, NubType? _)
{ {
var target = CheckExpression(expression.Target); var target = CheckExpression(expression.Target);
if (target.Type is not NubPointerType pointerType) if (target.Type is not NubPointerType pointerType)
@@ -543,7 +531,7 @@ public sealed class TypeChecker
return new DereferenceNode(expression.Tokens, pointerType.BaseType, target); return new DereferenceNode(expression.Tokens, pointerType.BaseType, target);
} }
private FuncCallNode CheckFuncCall(FuncCallSyntax expression) private FuncCallNode CheckFuncCall(FuncCallSyntax expression, NubType? _)
{ {
var accessor = CheckExpression(expression.Expression); var accessor = CheckExpression(expression.Expression);
if (accessor.Type is not NubFuncType funcType) if (accessor.Type is not NubFuncType funcType)
@@ -563,13 +551,13 @@ public sealed class TypeChecker
for (var i = 0; i < expression.Parameters.Count; i++) for (var i = 0; i < expression.Parameters.Count; i++)
{ {
var parameter = expression.Parameters[i]; var parameter = expression.Parameters[i];
var expectedType = funcType.Parameters[i]; var expectedParameterType = funcType.Parameters[i];
var parameterExpression = CheckExpression(parameter, expectedType); var parameterExpression = CheckExpression(parameter, expectedParameterType);
if (parameterExpression.Type != expectedType) if (parameterExpression.Type != expectedParameterType)
{ {
throw new TypeCheckerException(Diagnostic throw new TypeCheckerException(Diagnostic
.Error($"Parameter {i + 1} does not match the type {expectedType} for function {funcType}") .Error($"Parameter {i + 1} does not match the type {expectedParameterType} for function {funcType}")
.At(parameter) .At(parameter)
.Build()); .Build());
} }
@@ -580,7 +568,7 @@ public sealed class TypeChecker
return new FuncCallNode(expression.Tokens, funcType.ReturnType, accessor, parameters); return new FuncCallNode(expression.Tokens, funcType.ReturnType, accessor, parameters);
} }
private ExpressionNode CheckLocalIdentifier(LocalIdentifierSyntax expression) private ExpressionNode CheckLocalIdentifier(LocalIdentifierSyntax expression, NubType? _)
{ {
var scopeIdent = CurrentScope.LookupVariable(expression.Name); var scopeIdent = CurrentScope.LookupVariable(expression.Name);
if (scopeIdent != null) if (scopeIdent != null)
@@ -601,7 +589,7 @@ public sealed class TypeChecker
throw new TypeCheckerException(Diagnostic.Error($"Symbol {expression.Name} not found").At(expression).Build()); throw new TypeCheckerException(Diagnostic.Error($"Symbol {expression.Name} not found").At(expression).Build());
} }
private ExpressionNode CheckModuleIdentifier(ModuleIdentifierSyntax expression) private ExpressionNode CheckModuleIdentifier(ModuleIdentifierSyntax expression, NubType? _)
{ {
if (!_importedModules.TryGetValue(expression.Module, out var module)) if (!_importedModules.TryGetValue(expression.Module, out var module))
{ {
@@ -630,51 +618,62 @@ public sealed class TypeChecker
private ExpressionNode CheckStringLiteral(StringLiteralSyntax expression, NubType? expectedType) private ExpressionNode CheckStringLiteral(StringLiteralSyntax expression, NubType? expectedType)
{ {
return expectedType switch if (expectedType is NubCStringType)
{ {
NubCStringType => new CStringLiteralNode(expression.Tokens, expectedType, expression.Value), return new CStringLiteralNode(expression.Tokens, expression.Value);
NubStringType => new StringLiteralNode(expression.Tokens, expectedType, expression.Value), }
_ => new StringLiteralNode(expression.Tokens, new NubStringType(), expression.Value)
}; return new StringLiteralNode(expression.Tokens, expression.Value);
} }
private ExpressionNode CheckIntLiteral(IntLiteralSyntax expression, NubType? expectedType) private ExpressionNode CheckIntLiteral(IntLiteralSyntax expression, NubType? expectedType)
{ {
if (expectedType is NubIntType intType) if (expectedType is NubIntType intType)
{ {
return intType.Signed return intType.Width switch
? new IntLiteralNode(expression.Tokens, intType, Convert.ToInt64(expression.Value, expression.Base)) {
: new UIntLiteralNode(expression.Tokens, intType, Convert.ToUInt64(expression.Value, expression.Base)); 8 => intType.Signed ? new I8LiteralNode(expression.Tokens, Convert.ToSByte(expression.Value, expression.Base)) : new U8LiteralNode(expression.Tokens, Convert.ToByte(expression.Value, expression.Base)),
16 => intType.Signed ? new I16LiteralNode(expression.Tokens, Convert.ToInt16(expression.Value, expression.Base)) : new U16LiteralNode(expression.Tokens, Convert.ToUInt16(expression.Value, expression.Base)),
32 => intType.Signed ? new I32LiteralNode(expression.Tokens, Convert.ToInt32(expression.Value, expression.Base)) : new U32LiteralNode(expression.Tokens, Convert.ToUInt32(expression.Value, expression.Base)),
64 => intType.Signed ? new I64LiteralNode(expression.Tokens, Convert.ToInt64(expression.Value, expression.Base)) : new U64LiteralNode(expression.Tokens, Convert.ToUInt64(expression.Value, expression.Base)),
_ => throw new ArgumentOutOfRangeException()
};
} }
if (expectedType is NubFloatType floatType) if (expectedType is NubFloatType floatType)
{ {
return floatType.Width switch return floatType.Width switch
{ {
32 => new Float32LiteralNode(expression.Tokens, floatType, Convert.ToSingle(Convert.ToInt64(expression.Value, expression.Base))), 32 => new Float32LiteralNode(expression.Tokens, Convert.ToSingle(expression.Value)),
64 => new Float64LiteralNode(expression.Tokens, floatType, Convert.ToDouble(Convert.ToInt64(expression.Value, expression.Base))), 64 => new Float64LiteralNode(expression.Tokens, Convert.ToDouble(expression.Value)),
_ => throw new ArgumentOutOfRangeException() _ => throw new ArgumentOutOfRangeException()
}; };
} }
var type = new NubIntType(true, 64); return new I64LiteralNode(expression.Tokens, Convert.ToInt64(expression.Value, expression.Base));
return new IntLiteralNode(expression.Tokens, type, Convert.ToInt64(expression.Value, expression.Base));
} }
private ExpressionNode CheckFloatLiteral(FloatLiteralSyntax expression, NubType? expectedType) private ExpressionNode CheckFloatLiteral(FloatLiteralSyntax expression, NubType? expectedType)
{ {
var type = expectedType as NubFloatType ?? new NubFloatType(64); if (expectedType is NubFloatType floatType)
return type.Width == 32 {
? new Float32LiteralNode(expression.Tokens, type, Convert.ToSingle(expression.Value)) return floatType.Width switch
: new Float64LiteralNode(expression.Tokens, type, Convert.ToDouble(expression.Value)); {
32 => new Float32LiteralNode(expression.Tokens, Convert.ToSingle(expression.Value)),
64 => new Float64LiteralNode(expression.Tokens, Convert.ToDouble(expression.Value)),
_ => throw new ArgumentOutOfRangeException()
};
} }
private BoolLiteralNode CheckBoolLiteral(BoolLiteralSyntax expression) return new Float64LiteralNode(expression.Tokens, Convert.ToDouble(expression.Value));
}
private BoolLiteralNode CheckBoolLiteral(BoolLiteralSyntax expression, NubType? _)
{ {
return new BoolLiteralNode(expression.Tokens, new NubBoolType(), expression.Value); return new BoolLiteralNode(expression.Tokens, new NubBoolType(), expression.Value);
} }
private ExpressionNode CheckMemberAccess(MemberAccessSyntax expression) private ExpressionNode CheckMemberAccess(MemberAccessSyntax expression, NubType? _)
{ {
var target = CheckExpression(expression.Target); var target = CheckExpression(expression.Target);
switch (target.Type) switch (target.Type)
@@ -897,6 +896,7 @@ public record Variable(string Name, NubType Type);
public class Scope(string module, Scope? parent = null) public class Scope(string module, Scope? parent = null)
{ {
private NubType? _returnType;
private readonly List<Variable> _variables = []; private readonly List<Variable> _variables = [];
public string Module { get; } = module; public string Module { get; } = module;
@@ -905,6 +905,16 @@ public class Scope(string module, Scope? parent = null)
_variables.Add(variable); _variables.Add(variable);
} }
public void SetReturnType(NubType returnType)
{
_returnType = returnType;
}
public NubType? GetReturnType()
{
return _returnType ?? parent?.GetReturnType();
}
public Variable? LookupVariable(string name) public Variable? LookupVariable(string name)
{ {
var variable = _variables.FirstOrDefault(x => x.Name == name); var variable = _variables.FirstOrDefault(x => x.Name == name);

View File

@@ -14,12 +14,12 @@ public static class CType
NubFloatType floatType => CreateFloatType(floatType, variableName), NubFloatType floatType => CreateFloatType(floatType, variableName),
NubCStringType => "char*" + (variableName != null ? $" {variableName}" : ""), NubCStringType => "char*" + (variableName != null ? $" {variableName}" : ""),
NubPointerType ptr => CreatePointerType(ptr, variableName), NubPointerType ptr => CreatePointerType(ptr, variableName),
NubSliceType nubSliceType => "slice" + (variableName != null ? $" {variableName}" : ""), NubSliceType => "slice" + (variableName != null ? $" {variableName}" : ""),
NubStringType nubStringType => "string" + (variableName != null ? $" {variableName}" : ""), NubStringType => "string" + (variableName != null ? $" {variableName}" : ""),
NubConstArrayType arr => CreateConstArrayType(arr, variableName), NubConstArrayType arr => CreateConstArrayType(arr, variableName),
NubArrayType arr => CreateArrayType(arr, variableName), NubArrayType arr => CreateArrayType(arr, variableName),
NubFuncType fn => CreateFuncType(fn, variableName), NubFuncType fn => CreateFuncType(fn, variableName),
NubStructType st => $"{st.Name}" + (variableName != null ? $" {variableName}" : ""), NubStructType st => $"{st.Module}_{st.Name}" + (variableName != null ? $" {variableName}" : ""),
_ => throw new NotSupportedException($"C type generation not supported for: {type}") _ => throw new NotSupportedException($"C type generation not supported for: {type}")
}; };
} }

View File

@@ -35,12 +35,10 @@ public class Generator
public string Emit() public string Emit()
{ {
_writer.WriteLine("#include <stdint.h>");
_writer.WriteLine("#include <stdarg.h>");
_writer.WriteLine("#include <stddef.h>");
_writer.WriteLine();
_writer.WriteLine(""" _writer.WriteLine("""
#include <stdint.h>
#include <stddef.h>
typedef struct typedef struct
{ {
size_t length; size_t length;
@@ -274,14 +272,20 @@ public class Generator
FloatToIntBuiltinNode floatToIntBuiltinNode => EmitFloatToIntBuiltin(floatToIntBuiltinNode), FloatToIntBuiltinNode floatToIntBuiltinNode => EmitFloatToIntBuiltin(floatToIntBuiltinNode),
FuncCallNode funcCallNode => EmitFuncCall(funcCallNode), FuncCallNode funcCallNode => EmitFuncCall(funcCallNode),
FuncIdentifierNode funcIdentifierNode => FuncName(funcIdentifierNode.Module, funcIdentifierNode.Name, funcIdentifierNode.ExternSymbol), FuncIdentifierNode funcIdentifierNode => FuncName(funcIdentifierNode.Module, funcIdentifierNode.Name, funcIdentifierNode.ExternSymbol),
IntLiteralNode intLiteralNode => EmitIntLiteral(intLiteralNode),
AddressOfNode addressOfNode => EmitAddressOf(addressOfNode), AddressOfNode addressOfNode => EmitAddressOf(addressOfNode),
SizeBuiltinNode sizeBuiltinNode => $"sizeof({CType.Create(sizeBuiltinNode.TargetType)})", SizeBuiltinNode sizeBuiltinNode => $"sizeof({CType.Create(sizeBuiltinNode.TargetType)})",
SliceIndexAccessNode sliceIndexAccessNode => EmitSliceArrayIndexAccess(sliceIndexAccessNode), SliceIndexAccessNode sliceIndexAccessNode => EmitSliceArrayIndexAccess(sliceIndexAccessNode),
StringLiteralNode stringLiteralNode => EmitStringLiteral(stringLiteralNode), StringLiteralNode stringLiteralNode => EmitStringLiteral(stringLiteralNode),
StructFieldAccessNode structFieldAccessNode => EmitStructFieldAccess(structFieldAccessNode), StructFieldAccessNode structFieldAccessNode => EmitStructFieldAccess(structFieldAccessNode),
StructInitializerNode structInitializerNode => EmitStructInitializer(structInitializerNode), StructInitializerNode structInitializerNode => EmitStructInitializer(structInitializerNode),
UIntLiteralNode uIntLiteralNode => EmitUIntLiteral(uIntLiteralNode), I8LiteralNode i8LiteralNode => EmitI8Literal(i8LiteralNode),
I16LiteralNode i16LiteralNode => EmitI16Literal(i16LiteralNode),
I32LiteralNode i32LiteralNode => EmitI32Literal(i32LiteralNode),
I64LiteralNode i64LiteralNode => EmitI64Literal(i64LiteralNode),
U8LiteralNode u8LiteralNode => EmitU8Literal(u8LiteralNode),
U16LiteralNode u16LiteralNode => EmitU16Literal(u16LiteralNode),
U32LiteralNode u32LiteralNode => EmitU32Literal(u32LiteralNode),
U64LiteralNode u64LiteralNode => EmitU64Literal(u64LiteralNode),
UnaryExpressionNode unaryExpressionNode => EmitUnaryExpression(unaryExpressionNode), UnaryExpressionNode unaryExpressionNode => EmitUnaryExpression(unaryExpressionNode),
VariableIdentifierNode variableIdentifierNode => variableIdentifierNode.Name, VariableIdentifierNode variableIdentifierNode => variableIdentifierNode.Name,
_ => throw new ArgumentOutOfRangeException(nameof(expressionNode)) _ => throw new ArgumentOutOfRangeException(nameof(expressionNode))
@@ -345,7 +349,7 @@ public class Generator
private string EmitConvertFloat(ConvertFloatNode convertFloatNode) private string EmitConvertFloat(ConvertFloatNode convertFloatNode)
{ {
var value = EmitExpression(convertFloatNode.Value); var value = EmitExpression(convertFloatNode.Value);
var targetCast = convertFloatNode.TargetType.Width switch var targetCast = convertFloatNode.TargetWidth switch
{ {
32 => "f32", 32 => "f32",
64 => "f64", 64 => "f64",
@@ -358,12 +362,12 @@ public class Generator
private string EmitConvertInt(ConvertIntNode convertIntNode) private string EmitConvertInt(ConvertIntNode convertIntNode)
{ {
var value = EmitExpression(convertIntNode.Value); var value = EmitExpression(convertIntNode.Value);
var targetType = convertIntNode.TargetType.Width switch var targetType = convertIntNode.TargetWidth switch
{ {
8 => convertIntNode.TargetType.Signed ? "int8_t" : "uint8_t", 8 => convertIntNode.TargetSignedness ? "int8_t" : "uint8_t",
16 => convertIntNode.TargetType.Signed ? "int16_t" : "uint16_t", 16 => convertIntNode.TargetSignedness ? "int16_t" : "uint16_t",
32 => convertIntNode.TargetType.Signed ? "int32_t" : "uint32_t", 32 => convertIntNode.TargetSignedness ? "int32_t" : "uint32_t",
64 => convertIntNode.TargetType.Signed ? "int64_t" : "uint64_t", 64 => convertIntNode.TargetSignedness ? "int64_t" : "uint64_t",
_ => throw new ArgumentOutOfRangeException() _ => throw new ArgumentOutOfRangeException()
}; };
return $"({targetType}){value}"; return $"({targetType}){value}";
@@ -420,21 +424,10 @@ public class Generator
private string EmitFuncCall(FuncCallNode funcCallNode) private string EmitFuncCall(FuncCallNode funcCallNode)
{ {
var name = EmitExpression(funcCallNode.Expression); var name = EmitExpression(funcCallNode.Expression);
var parameterNames = funcCallNode.Parameters.Select(x => EmitExpression(x)).ToList(); var parameterNames = funcCallNode.Parameters.Select(EmitExpression).ToList();
return $"{name}({string.Join(", ", parameterNames)})"; return $"{name}({string.Join(", ", parameterNames)})";
} }
private string EmitIntLiteral(IntLiteralNode intLiteralNode)
{
var type = (NubIntType)intLiteralNode.Type;
return type.Width switch
{
8 or 16 or 32 => intLiteralNode.Value.ToString(),
64 => intLiteralNode.Value + "LL",
_ => throw new ArgumentOutOfRangeException()
};
}
private string EmitAddressOf(AddressOfNode addressOfNode) private string EmitAddressOf(AddressOfNode addressOfNode)
{ {
var value = EmitExpression(addressOfNode.LValue); var value = EmitExpression(addressOfNode.LValue);
@@ -477,15 +470,44 @@ public class Generator
return $"({CType.Create(structInitializerNode.Type)}){{{initString}}}"; return $"({CType.Create(structInitializerNode.Type)}){{{initString}}}";
} }
private string EmitUIntLiteral(UIntLiteralNode uIntLiteralNode) private string EmitI8Literal(I8LiteralNode i8LiteralNode)
{ {
var type = (NubIntType)uIntLiteralNode.Type; return i8LiteralNode.Value.ToString();
return type.Width switch }
private string EmitI16Literal(I16LiteralNode i16LiteralNode)
{ {
8 or 16 or 32 => uIntLiteralNode.Value + "U", return i16LiteralNode.Value.ToString();
64 => uIntLiteralNode.Value + "ULL", }
_ => throw new ArgumentOutOfRangeException()
}; private string EmitI32Literal(I32LiteralNode i32LiteralNode)
{
return i32LiteralNode.Value.ToString();
}
private string EmitI64Literal(I64LiteralNode i64LiteralNode)
{
return i64LiteralNode.Value + "LL";
}
private string EmitU8Literal(U8LiteralNode u8LiteralNode)
{
return u8LiteralNode.Value.ToString();
}
private string EmitU16Literal(U16LiteralNode u16LiteralNode)
{
return u16LiteralNode.Value.ToString();
}
private string EmitU32Literal(U32LiteralNode u32LiteralNode)
{
return u32LiteralNode.Value.ToString();
}
private string EmitU64Literal(U64LiteralNode u64LiteralNode)
{
return u64LiteralNode.Value + "ULL";
} }
private string EmitUnaryExpression(UnaryExpressionNode unaryExpressionNode) private string EmitUnaryExpression(UnaryExpressionNode unaryExpressionNode)

View File

@@ -5,5 +5,6 @@ extern "puts" func puts(text: cstring)
extern "main" func main(argc: i64, argv: [?]cstring): i64 extern "main" func main(argc: i64, argv: [?]cstring): i64
{ {
let x = [23]i32 let x = [23]i32
puts("test")
return x[0] return x[0]
} }

View File

@@ -1,5 +1,5 @@
.build/out: main.nub .build/out: main.nub
nubc main.nub clang $$(nubc main.nub) -o .build/out
clean: clean:
@rm -r .build 2>/dev/null || true @rm -r .build 2>/dev/null || true

View File

@@ -1,5 +1,5 @@
.build/out: main.nub .build/out: main.nub
nubc main.nub clang $$(nubc main.nub) -o .build/out
clean: clean:
@rm -r .build 2>/dev/null || true @rm -r .build 2>/dev/null || true