@@ -1,5 +1,6 @@
using System.Data.Common ;
using System.Data.Common ;
using System.Diagnostics.CodeAnalysis ;
using System.Diagnostics.CodeAnalysis ;
using System.Formats.Tar ;
namespace Compiler ;
namespace Compiler ;
@@ -21,6 +22,7 @@ public class TypeChecker
private readonly string fileName ;
private readonly string fileName ;
private readonly string currentModule ;
private readonly string currentModule ;
private readonly NodeDefinitionFunc function ;
private readonly NodeDefinitionFunc function ;
private NubType functionReturnType = null ! ;
private readonly ModuleGraph moduleGraph ;
private readonly ModuleGraph moduleGraph ;
private readonly Scope scope = new ( ) ;
private readonly Scope scope = new ( ) ;
@@ -31,7 +33,16 @@ public class TypeChecker
var parameters = new List < TypedNodeDefinitionFunc . Param > ( ) ;
var parameters = new List < TypedNodeDefinitionFunc . Param > ( ) ;
var invalidParameter = false ;
var invalidParameter = false ;
TypedNodeStatement ? body = null ;
TypedNodeStatement ? body = null ;
NubType ? returnType = null ;
try
{
functionReturnType = ResolveType ( function . ReturnType ) ;
}
catch ( CompileException e )
{
diagnostics . Add ( e . Diagnostic ) ;
return null ;
}
using ( scope . EnterScope ( ) )
using ( scope . EnterScope ( ) )
{
{
@@ -63,19 +74,10 @@ public class TypeChecker
diagnostics . Add ( e . Diagnostic ) ;
diagnostics . Add ( e . Diagnostic ) ;
}
}
try
if ( body = = null | | invalidParameter )
{
returnType = ResolveType ( function . ReturnType ) ;
}
catch ( CompileException e )
{
diagnostics . Add ( e . Diagnostic ) ;
}
if ( body = = null | | returnType is null | | invalidParameter )
return null ;
return null ;
return new TypedNodeDefinitionFunc ( function . Tokens , currentModule , function . Name , parameters , body , r eturnType) ;
return new TypedNodeDefinitionFunc ( function . Tokens , currentModule , function . Name , parameters , body , functionR eturnType) ;
}
}
}
}
@@ -97,7 +99,10 @@ public class TypeChecker
private TypedNodeStatementAssignment CheckStatementAssignment ( NodeStatementAssignment statement )
private TypedNodeStatementAssignment CheckStatementAssignment ( NodeStatementAssignment statement )
{
{
return new TypedNodeStatementAssignment ( statement . Tokens , CheckExpression ( statement . Target ) , CheckExpression ( statement . Value ) ) ;
var target = CheckExpression ( statement . Target , null ) ;
var value = CheckExpression ( statement . Value , target . Type ) ;
return new TypedNodeStatementAssignment ( statement . Tokens , target , value ) ;
}
}
private TypedNodeStatementBlock CheckStatementBlock ( NodeStatementBlock statement )
private TypedNodeStatementBlock CheckStatementBlock ( NodeStatementBlock statement )
@@ -114,27 +119,56 @@ public class TypeChecker
if ( statement . Expression is not NodeExpressionFuncCall funcCall )
if ( statement . Expression is not NodeExpressionFuncCall funcCall )
throw BasicError ( "Expected statement or function call" , statement ) ;
throw BasicError ( "Expected statement or function call" , statement ) ;
return new TypedNodeStatementFuncCall ( statement . Tokens , CheckExpression ( funcCall . Target ) , funcCall . Parameters . Select ( CheckExpression ) . ToList ( ) ) ;
var target = CheckExpression ( funcCall . Target , null ) ;
if ( target . Type is not NubTypeFunc funcType )
throw BasicError ( "Expected a function type" , target ) ;
if ( funcType . Parameters . Count ! = funcCall . Parameters . Count )
throw BasicError ( $"Expected {funcType.Parameters.Count} parameters but got {funcCall.Parameters.Count}" , funcCall ) ;
var parameters = new List < TypedNodeExpression > ( ) ;
for ( int i = 0 ; i < funcCall . Parameters . Count ; i + + )
{
parameters . Add ( CheckExpression ( funcCall . Parameters [ i ] , funcType . Parameters [ i ] ) ) ;
}
return new TypedNodeStatementFuncCall ( statement . Tokens , target , parameters ) ;
}
}
private TypedNodeStatementIf CheckStatementIf ( NodeStatementIf statement )
private TypedNodeStatementIf CheckStatementIf ( NodeStatementIf statement )
{
{
return new TypedNodeStatementIf ( statement . Tokens , CheckExpression ( statement . Condition ) , CheckStatement ( statement . ThenBlock ) , statement . ElseBlock = = null ? null : CheckStatement ( statement . ElseBlock ) ) ;
var condition = CheckExpression ( statement . Condition , NubTypeBool . Instance ) ;
if ( ! condition . Type . IsAssignableTo ( NubTypeBool . Instance ) )
throw BasicError ( "Condition part of if statement must be a boolean" , condition ) ;
var thenBlock = CheckStatement ( statement . ThenBlock ) ;
var elseBlock = statement . ElseBlock = = null ? null : CheckStatement ( statement . ElseBlock ) ;
return new TypedNodeStatementIf ( statement . Tokens , condition , thenBlock , elseBlock ) ;
}
}
private TypedNodeStatementReturn CheckStatementReturn ( NodeStatementReturn statement )
private TypedNodeStatementReturn CheckStatementReturn ( NodeStatementReturn statement )
{
{
return new TypedNodeStatementReturn ( statement . Tokens , CheckExpression ( statement . Value ) ) ;
var value = CheckExpression ( statement . Value , functionReturnType ) ;
if ( ! value . Type . IsAssignableTo ( functionReturnType ) )
throw BasicError ( $"Type of returned value ({value.Type}) is not assignable to the return type of the function ({functionReturnType})" , value ) ;
return new TypedNodeStatementReturn ( statement . Tokens , value ) ;
}
}
private TypedNodeStatementVariableDeclaration CheckStatementVariableDeclaration ( NodeStatementVariableDeclaration statement )
private TypedNodeStatementVariableDeclaration CheckStatementVariableDeclaration ( NodeStatementVariableDeclaration statement )
{
{
var type = ResolveType ( statement . Type ) ;
NubType ? type = null ;
var value = CheckExpression ( statement . Value ) ;
if ( statement . Type ! = null )
type = ResolveType ( statement . Type ) ;
if ( ! value . Type . IsAssignableTo ( type ) )
var value = CheckExpression ( statement . Value , type ) ;
if ( type is not null & & ! value . Type . IsAssignableTo ( type ) )
throw BasicError ( "Type of variable does match type of assigned value" , value ) ;
throw BasicError ( "Type of variable does match type of assigned value" , value ) ;
type ? ? = value . Type ;
scope . DeclareIdentifier ( statement . Name . Ident , type ) ;
scope . DeclareIdentifier ( statement . Name . Ident , type ) ;
return new TypedNodeStatementVariableDeclaration ( statement . Tokens , statement . Name , type , value ) ;
return new TypedNodeStatementVariableDeclaration ( statement . Tokens , statement . Name , type , value ) ;
@@ -142,15 +176,22 @@ public class TypeChecker
private TypedNodeStatementWhile CheckStatementWhile ( NodeStatementWhile statement )
private TypedNodeStatementWhile CheckStatementWhile ( NodeStatementWhile statement )
{
{
return new TypedNodeStatementWhile ( statement . Tokens , CheckExpression ( statement . Condition ) , CheckStatement ( statement . Body ) ) ;
var condition = CheckExpression ( statement . Condition , NubTypeBool . In stance ) ;
if ( ! condition . Type . IsAssignableTo ( NubTypeBool . Instance ) )
throw BasicError ( "Condition part of if statement must be a boolean" , condition ) ;
var body = CheckStatement ( statement . Body ) ;
return new TypedNodeStatementWhile ( statement . Tokens , condition , body ) ;
}
}
private TypedNodeStatementMatch CheckStatementMatch ( NodeStatementMatch statement )
private TypedNodeStatementMatch CheckStatementMatch ( NodeStatementMatch statement )
{
{
var cases = new List < TypedNodeStatementMatch . Case > ( ) ;
var target = CheckExpression ( statement . Target , null ) ;
var target = CheckExpression ( statement . Target ) ;
if ( target . Type is not NubTypeEnum enumType )
var enumType = ( NubTypeEnum ) target . Type ;
throw BasicError ( "A match statement can only be used on enum types" , target ) ;
var cases = new List < TypedNodeStatementMatch . Case > ( ) ;
foreach ( var @case in statement . Cases )
foreach ( var @case in statement . Cases )
{
{
using ( scope . EnterScope ( ) )
using ( scope . EnterScope ( ) )
@@ -164,28 +205,29 @@ public class TypeChecker
return new TypedNodeStatementMatch ( statement . Tokens , target , cases ) ;
return new TypedNodeStatementMatch ( statement . Tokens , target , cases ) ;
}
}
private TypedNodeExpression CheckExpression ( NodeExpression node )
private TypedNodeExpression CheckExpression ( NodeExpression node , NubType ? expectedType )
{
{
return node switch
return node switch
{
{
NodeExpressionBinary expression = > CheckExpressionBinary ( expression ) ,
NodeExpressionBinary expression = > CheckExpressionBinary ( expression , expectedType ),
NodeExpressionUnary expression = > CheckExpressionUnary ( expression ) ,
NodeExpressionUnary expression = > CheckExpressionUnary ( expression , expectedType ),
NodeExpressionBoolLiteral expression = > CheckExpressionBoolLiteral ( expression ) ,
NodeExpressionBoolLiteral expression = > CheckExpressionBoolLiteral ( expression , expectedType ),
NodeExpressionIdent expression = > CheckExpressionIdent ( expression ) ,
NodeExpressionIdent expression = > CheckExpressionIdent ( expression , expectedType ),
NodeExpressionIntLiteral expression = > CheckExpressionIntLiteral ( expression ) ,
NodeExpressionIntLiteral expression = > CheckExpressionIntLiteral ( expression , expectedType ),
NodeExpressionMemberAccess expression = > CheckExpressionMemberAccess ( expression ) ,
NodeExpressionMemberAccess expression = > CheckExpressionMemberAccess ( expression , expectedType ),
NodeExpressionFuncCall expression = > CheckExpressionFuncCall ( expression ) ,
NodeExpressionFuncCall expression = > CheckExpressionFuncCall ( expression , expectedType ),
NodeExpressionStringLiteral expression = > CheckExpressionStringLiteral ( expression ) ,
NodeExpressionStringLiteral expression = > CheckExpressionStringLiteral ( expression , expectedType ),
NodeExpressionStructLiteral expression = > CheckExpressionStructLiteral ( expression ) ,
NodeExpressionStructLiteral expression = > CheckExpressionStructLiteral ( expression , expectedType ),
NodeExpressionEnumLiteral expression = > CheckExpressionEnumLiteral ( expression ) ,
NodeExpressionEnumLiteral expression = > CheckExpressionEnumLiteral ( expression , expectedType ),
_ = > throw new ArgumentOutOfRangeException ( nameof ( node ) )
_ = > throw new ArgumentOutOfRangeException ( nameof ( node ) )
} ;
} ;
}
}
private TypedNodeExpressionBinary CheckExpressionBinary ( NodeExpressionBinary expression )
private TypedNodeExpressionBinary CheckExpressionBinary ( NodeExpressionBinary expression , NubType ? expectedType )
{
{
var left = CheckExpression ( expression . Left ) ;
// todo(nub31): Add proper inference here
var righ t = CheckExpression ( expression . Right ) ;
var lef t = CheckExpression ( expression . Left , null ) ;
var right = CheckExpression ( expression . Right , null ) ;
NubType type ;
NubType type ;
switch ( expression . Operation )
switch ( expression . Operation )
@@ -275,9 +317,10 @@ public class TypeChecker
} ;
} ;
}
}
private TypedNodeExpressionUnary CheckExpressionUnary ( NodeExpressionUnary expression )
private TypedNodeExpressionUnary CheckExpressionUnary ( NodeExpressionUnary expression , NubType ? expectedType )
{
{
var target = CheckExpression ( expression . Target ) ;
// todo(nub31): Add proper inference here
var target = CheckExpression ( expression . Target , null ) ;
NubType type ;
NubType type ;
switch ( expression . Operation )
switch ( expression . Operation )
@@ -315,12 +358,12 @@ public class TypeChecker
} ;
} ;
}
}
private TypedNodeExpressionBoolLiteral CheckExpressionBoolLiteral ( NodeExpressionBoolLiteral expression )
private TypedNodeExpressionBoolLiteral CheckExpressionBoolLiteral ( NodeExpressionBoolLiteral expression , NubType ? expectedType )
{
{
return new TypedNodeExpressionBoolLiteral ( expression . Tokens , NubTypeBool . Instance , expression . Value ) ;
return new TypedNodeExpressionBoolLiteral ( expression . Tokens , NubTypeBool . Instance , expression . Value ) ;
}
}
private TypedNodeExpression CheckExpressionIdent ( NodeExpressionIdent expression )
private TypedNodeExpression CheckExpressionIdent ( NodeExpressionIdent expression , NubType ? expectedType )
{
{
if ( expression . Sections . Count = = 1 )
if ( expression . Sections . Count = = 1 )
{
{
@@ -345,14 +388,14 @@ public class TypeChecker
throw BasicError ( $"Unknown identifier '{string.Join(" : : ", expression.Sections.Select(x => x.Ident))}'" , expression ) ;
throw BasicError ( $"Unknown identifier '{string.Join(" : : ", expression.Sections.Select(x => x.Ident))}'" , expression ) ;
}
}
private TypedNodeExpressionIntLiteral CheckExpressionIntLiteral ( NodeExpressionIntLiteral expression )
private TypedNodeExpressionIntLiteral CheckExpressionIntLiteral ( NodeExpressionIntLiteral expression , NubType ? expectedType )
{
{
return new TypedNodeExpressionIntLiteral ( expression . Tokens , NubTypeSInt . Get ( 32 ) , expression . Value ) ;
return new TypedNodeExpressionIntLiteral ( expression . Tokens , NubTypeSInt . Get ( 32 ) , expression . Value ) ;
}
}
private TypedNodeExpressionMemberAccess CheckExpressionMemberAccess ( NodeExpressionMemberAccess expression )
private TypedNodeExpressionMemberAccess CheckExpressionMemberAccess ( NodeExpressionMemberAccess expression , NubType ? expectedType )
{
{
var target = CheckExpression ( expression . Target ) ;
var target = CheckExpression ( expression . Target , null );
switch ( target . Type )
switch ( target . Type )
{
{
@@ -395,55 +438,105 @@ public class TypeChecker
}
}
}
}
private TypedNodeExpressionFuncCall CheckExpressionFuncCall ( NodeExpressionFuncCall expression )
private TypedNodeExpressionFuncCall CheckExpressionFuncCall ( NodeExpressionFuncCall expression , NubType ? expectedType )
{
{
var target = CheckExpression ( expression . Target ) ;
var target = CheckExpression ( expression . Target , null );
if ( target . Type is not NubTypeFunc funcType )
if ( target . Type is not NubTypeFunc funcType )
throw BasicError ( $"Cannot invoke function call on type '{target.Type}' ", target ) ;
throw BasicError ( "Expected a function type ", target ) ;
var parameters = expression . Parameters . Select ( CheckExpression ) . ToList ( ) ;
if ( funcType . Parameters . Count ! = expression . Parameters . Count )
throw BasicError ( $"Expected {funcType.Parameters.Count} parameters but got {expression.Parameters.Count}" , expression ) ;
var parameters = new List < TypedNodeExpression > ( ) ;
for ( int i = 0 ; i < expression . Parameters . Count ; i + + )
{
parameters . Add ( CheckExpression ( expression . Parameters [ i ] , funcType . Parameters [ i ] ) ) ;
}
return new TypedNodeExpressionFuncCall ( expression . Tokens , funcType . ReturnType , target , parameters ) ;
return new TypedNodeExpressionFuncCall ( expression . Tokens , funcType . ReturnType , target , parameters ) ;
}
}
private TypedNodeExpressionStringLiteral CheckExpressionStringLiteral ( NodeExpressionStringLiteral expression )
private TypedNodeExpressionStringLiteral CheckExpressionStringLiteral ( NodeExpressionStringLiteral expression , NubType ? expectedType )
{
{
return new TypedNodeExpressionStringLiteral ( expression . Tokens , NubTypeString . Instance , expression . Value ) ;
return new TypedNodeExpressionStringLiteral ( expression . Tokens , NubTypeString . Instance , expression . Value ) ;
}
}
private TypedNodeExpressionStructLiteral CheckExpressionStructLiteral ( NodeExpressionStructLiteral expression )
private TypedNodeExpressionStructLiteral CheckExpressionStructLiteral ( NodeExpressionStructLiteral expression , NubType ? expectedType )
{
{
var type = ResolveType ( expression . Type ) ;
if ( expression . Type ! = null )
if ( type is not NubTypeStruct structType )
throw BasicError ( "Type of struct literal is not a struct" , expression . Type ) ;
if ( ! moduleGraph . TryResolveType ( structType . Module , structType . Name , structType . Module = = currentModule , out var info ) )
throw BasicError ( $"Type '{structType}' struct literal not found" , expression . Type ) ;
if ( info is not Module . TypeInfoStruct structInfo )
throw BasicError ( $"Type '{structType}' is not a struct" , expression . Type ) ;
var initializers = new List < TypedNodeExpressionStructLiteral . Initializer > ( ) ;
foreach ( var initializer in expression . Initializers )
{
{
var field = structInfo . Fields . FirstOrDefault ( x = > x . Name = = initializer . Name . Ident ) ;
var type = ResolveType ( expression . Type ) ;
if ( field = = null )
if ( type is not NubTypeStruct structType )
throw BasicError ( $"Field '{initia liz er.Name.Ident}' does not exist on struct '{structType.Module}::{structType.Name}'" , initializer . Name ) ;
throw BasicError ( "Type of struct lit eral is not a struct" , expression ) ;
if ( ! moduleGraph . TryResolveType ( structType . Module , structType . Name , structType . Module = = currentModule , out var info ) )
throw BasicError ( $"Type '{structType}' struct literal not found" , expression ) ;
var value = CheckExpression ( initializer . Value ) ;
if ( info is not Module . TypeInfoStruct structInfo )
if ( ! value . Type . IsAssignableTo ( field . Type ) )
throw BasicError ( $"Type '{structType}' is not a struct" , expression . Type ) ;
throw BasicError ( $"Type of assignment ({value.Type}) does not match expected type of field '{field.Name}' ({field.Type})" , initializer . Name ) ;
initializers . Add ( new TypedNodeExpressionStructLiteral . Initializer ( initializer . Tokens , initializer . Name , value ) ) ;
var initializers = new List < TypedNodeExpressionStructLiteral . Initializer > ( ) ;
foreach ( var initializer in expression . Initializers )
{
var field = structInfo . Fields . FirstOrDefault ( x = > x . Name = = initializer . Name . Ident ) ;
if ( field = = null )
throw BasicError ( $"Field '{initializer.Name.Ident}' does not exist on struct '{structType.Module}::{structType.Name}'" , initializer . Name ) ;
var value = CheckExpression ( initializer . Value , field . Type ) ;
if ( ! value . Type . IsAssignableTo ( field . Type ) )
throw BasicError ( $"Type of assignment ({value.Type}) does not match expected type of field '{field.Name}' ({field.Type})" , initializer . Name ) ;
initializers . Add ( new TypedNodeExpressionStructLiteral . Initializer ( initializer . Tokens , initializer . Name , value ) ) ;
}
return new TypedNodeExpressionStructLiteral ( expression . Tokens , structType , initializers ) ;
}
}
else if ( expectedType is NubTypeStruct structType )
{
if ( ! moduleGraph . TryResolveType ( structType . Module , structType . Name , structType . Module = = currentModule , out var info ) )
throw BasicError ( $"Type '{structType}' struct literal not found" , expression ) ;
return new TypedNodeExpressionStructLiteral ( expression . Tokens , NubTypeStruct . Get ( structType . Module , structType . Name ) , initializers ) ;
if ( info is not Module . TypeInfoStruct structInfo )
throw BasicError ( $"Type '{structType}' is not a struct" , expression ) ;
var initializers = new List < TypedNodeExpressionStructLiteral . Initializer > ( ) ;
foreach ( var initializer in expression . Initializers )
{
var field = structInfo . Fields . FirstOrDefault ( x = > x . Name = = initializer . Name . Ident ) ;
if ( field = = null )
throw BasicError ( $"Field '{initializer.Name.Ident}' does not exist on struct '{structType.Module}::{structType.Name}'" , initializer . Name ) ;
var value = CheckExpression ( initializer . Value , field . Type ) ;
if ( ! value . Type . IsAssignableTo ( field . Type ) )
throw BasicError ( $"Type of assignment ({value.Type}) does not match expected type of field '{field.Name}' ({field.Type})" , initializer . Name ) ;
initializers . Add ( new TypedNodeExpressionStructLiteral . Initializer ( initializer . Tokens , initializer . Name , value ) ) ;
}
return new TypedNodeExpressionStructLiteral ( expression . Tokens , structType , initializers ) ;
}
// todo(nub31): Infer anonymous struct types if expectedType is anonymous struct
else
{
var initializers = new List < TypedNodeExpressionStructLiteral . Initializer > ( ) ;
foreach ( var initializer in expression . Initializers )
{
var value = CheckExpression ( initializer . Value , null ) ;
initializers . Add ( new TypedNodeExpressionStructLiteral . Initializer ( initializer . Tokens , initializer . Name , value ) ) ;
}
var type = NubTypeAnonymousStruct . Get ( initializers . Select ( x = > new NubTypeAnonymousStruct . Field ( x . Name . Ident , x . Value . Type ) ) . ToList ( ) ) ;
return new TypedNodeExpressionStructLiteral ( expression . Tokens , type , initializers ) ;
}
}
}
private TypedNodeExpressionEnumLiteral CheckExpressionEnumLiteral ( NodeExpressionEnumLiteral expression )
private TypedNodeExpressionEnumLiteral CheckExpressionEnumLiteral ( NodeExpressionEnumLiteral expression , NubType ? expectedType )
{
{
var value = CheckExpression ( expression . Value ) ;
// todo(nub31): Infer type of enum variant
return new TypedNodeExpressionEnumLiteral ( expression . Tokens , NubTypeEnumVariant . Get ( NubTypeEnum . Get ( expression . Module . Ident , expression . EnumName . Ident ) , expression . VariantName . Ident ) , value ) ;
var type = NubTypeEnumVariant . Get ( NubTypeEnum . Get ( expression . Module . Ident , expression . EnumName . Ident ) , expression . VariantName . Ident ) ;
var value = CheckExpression ( expression . Value , null ) ;
return new TypedNodeExpressionEnumLiteral ( expression . Tokens , type , value ) ;
}
}
private NubType ResolveType ( NodeType node )
private NubType ResolveType ( NodeType node )