Skip to content

Commit

Permalink
Merge pull request #3 from arhik/12-all-new-arguments-are-var-type
Browse files Browse the repository at this point in the history
var fix
  • Loading branch information
arhik committed Mar 21, 2024
2 parents 58cf094 + 9f7f21b commit f4ff1a9
Show file tree
Hide file tree
Showing 12 changed files with 260 additions and 86 deletions.
1 change: 1 addition & 0 deletions src/codegen/abstracts.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
abstract type JLExpr end
abstract type JLVariable <: JLExpr end
abstract type JLBlock <: JLExpr end
abstract type BinaryOp <: JLExpr end
abstract type JLBuiltIn <: JLExpr end
146 changes: 110 additions & 36 deletions src/codegen/assignment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@ typeInfer(scope::Scope, varRef::Ref{WGPUVariable}) = typeInfer(scope, varRef[])

struct LHS
expr::Union{Ref{WGPUVariable}, JLExpr}
newVar::Bool
end

isMutable(lhs::LHS) = isMutable(lhs.expr)
setMutable!(lhs::LHS, b::Bool) = setMutable!(lhs.expr, b)


isNew(lhs::LHS) = isNew(lhs.expr)
setNew!(lhs::LHS, b::Bool) = setNew!(lhs.expr, b)
isNew(lhs::LHS) = lhs.newVar

symbol(lhs::LHS) = symbol(lhs.expr)

Expand Down Expand Up @@ -70,24 +69,92 @@ end

symbol(assign::AssignmentExpr) = symbol(assign.lhs)

function assignExpr(scope, lhs::Symbol, rhs::Number)
rhsExpr = RHS(Scalar(rhs))
rhsType = typeInfer(scope, rhsExpr)
(lhsfound, lhslocation, lhsScope) = findVar(scope, lhs)
lhsExpr = Ref{LHS}()
if lhsfound && lhslocation == :localScope
lExpr = lhsScope.locals[lhs]
lhsExpr[] = LHS(lExpr, false)
rhsExpr = RHS(Scalar(rhs |> lExpr[].dataType))
setMutable!(lhsExpr[], true)
elseif lhsfound && lhslocation == :globalScope
lVar = lhsScope.globals[lhs]
lVarRef = Ref{WGPUVariable}(lVar)
lVarRef[].dataType = rhsType
lhsExprp[] = LHS(lVarRef, false)
setMutable!(lhsExpr[], true)
elseif lhsfound == false
lvar = inferExpr(scope, lhs)
scope.globals[lhs] = lvar[]
lvar[].dataType = rhsType
scope.locals[lhs] = lvar
#lvar[].undefined = true
lhsExpr[] = LHS(lvar, true)
setMutable!(lhsExpr[], false)
else
error("This should not have reached")
end
statement = AssignmentExpr(lhsExpr[], rhsExpr, scope)
return statement
end

function assignExpr(scope, lhs::Expr, rhs::Number)
lExpr = inferExpr(scope, lhs)
rhsExpr = RHS(Scalar(rhs))
rhsType = typeInfer(scope, rhsExpr)
(lhsfound, lhslocation, lhsScope) = findVar(scope, symbol(lExpr))
lhsExpr = Ref{LHS}()
if typeof(lExpr) == IndexExpr
if lhsfound && lhslocation == :localScope
lExpr = lhsScope.locals[lhs]
lhsExpr[] = LHS(lExpr, false)
rhsExpr = RHS(Scalar(rhs |> lExpr[].dataType))
setMutable!(lhsExpr[], true)
elseif lhsfound && lhslocation == :globalScope
lVar = lhsScope.globals[symbol(lExpr)]
lVarRef = Ref{WGPUVariable}(lVar)
rhsExpr = RHS(Scalar(rhs |> eltype(lVar.dataType)))
lhsExpr[] = LHS(lExpr, false)
setMutable!(lhsExpr[], true)
elseif lhsfound == false
lvar = inferExpr(scope, lhs)
scope.globals[lhs] = lvar[]
lvar[].dataType = rhsType
lhsExpr[] = LHS(lvar, true)
setMutable!(lhsExpr[], false)
else
error("This should not have reached")
end
elseif typeof(lExpr) == AccessExpr
else
error(" This expr type is not covered : $lExpr")
end
statement = AssignmentExpr(lhsExpr[], rhsExpr, scope)
return statement
end


function assignExpr(scope, lhs::Symbol, rhs::Symbol)
rhsExpr = RHS(inferExpr(scope, rhs))
rhsType = typeInfer(scope, rhsExpr)
(lhsfound, lhslocation, lhsScope) = findVar(scope, lhs)
lhsExpr = Ref{LHS}()
if lhsfound
if lhsfound && lhslocation == :localScope
lExpr = lhsScope.locals[lhs]
lhsExpr[] = LHS(lExpr)
lhsExpr[] = LHS(lExpr, false)
lExpr[].dataType = rhsType
setNew!(lhsExpr[], false)
setMutable!(lhsExpr[], true)
else
lVar = inferExpr(scope, lhs)
scope.locals[lhs] = lVar
lVar[].dataType = rhsType
lhsExpr[] = LHS(lVar)
setNew!(lhsExpr[], true)
setMutable!(lhsExpr[], false)
elseif lhsfound && lhslocation == :globalScope
lVar = lhsScope.globals[lhs]
lVarRef = Ref{WGPUVariable}(lVar)
#scope.locals[lhs] = lVarRef
lVarRef[].dataType = rhsType
lhsExpr[] = LHS(lVarRef, false)
setMutable!(lhsExpr[], true)
elseif found == false
# setMutable!(lhsExpr[], false)
end
statement = AssignmentExpr(lhsExpr[], rhsExpr, scope)
return statement
Expand All @@ -98,20 +165,30 @@ function assignExpr(scope, lhs::Symbol, rhs::Expr)
rhsType = typeInfer(scope, rhsExpr)
(found, location, rootScope) = findVar(scope, lhs)
lhsExpr = Ref{LHS}()
if found && location != :typeScope
if found && location == :localScope
lExpr = rootScope.locals[lhs]
lhsExpr[] = LHS(lExpr)
@assert lExpr[].dataType == rhsType
setNew!(lhsExpr[], false)
lhsExpr[] = LHS(lExpr, false)
@assert lExpr[].dataType == rhsType "$(lExpr[].dataType) != $rhsType"
lExpr[].undefined = false
setMutable!(lhsExpr[], true)
elseif found == false && location != :typeScope
elseif found && location == :globalScope
lExpr = rootScope.globals[lhs]
lExprRef = Ref{WGPUVariable}(lExpr)
rootScope.locals[lhs] = lExprRef
lhsExpr[] = LHS(lExprRef, false)
@assert lExprRef[].dataType == rhsType
lExprRef[].undefined = false
setMutable!(lhsExpr[], true)
elseif found == false && location == nothing
# new var
lvar = inferExpr(scope, lhs)
scope.locals[lhs] = lvar
scope.globals[lhs] = lvar[]
lvar[].dataType = rhsType
lhsExpr[] = LHS(lvar)
setNew!(lhsExpr[], true)
lvar[].undefined = true
lhsExpr[] = LHS(lvar, true)
setMutable!(lhsExpr[], false)
else
error("Not captured this case yet!!!!")
end
statement = AssignmentExpr(lhsExpr[], rhsExpr, scope)
return statement
Expand All @@ -128,38 +205,35 @@ function assignExpr(scope, lhs::Expr, rhs::Expr)
if found && location != :typeScope
lvar = location == :localScope ? rootScope.locals[symbol(lExpr)] : rootScope.globals[symbol(lExpr)]
#lvar = rootScope.locals[symbol(lExpr)]
lhsExpr[] = LHS(lExpr)
lhsExpr[] = LHS(lExpr, false)
lhsType = typeInfer(scope, lhsExpr[])
@assert lhsType == rhsType "$lhsType != $rhsType"
#setMutable!(lhsExpr[], true)
#setNew!(lhsExpr[], false)
else found == false
error("LHS var $(symbol(lhs)) had to be mutable for indexing")
end
elseif typeof(lExpr) == AccessExpr
(found, location, rootScope) = findVar(scope, symbol(lExpr))
if found && location !=:typeScope
lExpr = rootScope.locals[symbol(lExpr)]
lhsExpr[] = LHS(lExpr[])
lhsExpr[] = LHS(lExpr[], false)
setMutable!(lhsExpr[], true)
setNew!(lhsExpr[], false)
else found == false
error("LHS var $(symbol(lhs)) has to be mutable for `getproperty`")
end
elseif typeof(lExpr) == DeclExpr
(found, location, rootScope) = findVar(scope, symbol(lExpr))
if found && location !=:typeScope
lExpr = rootScope.locals[symbol(lExpr)]
lhsExpr[] = LHS(lExpr)
setMutable!(lhsExpr[], true)
setNew!(lhsExpr[], false)
if found && location ==:globalScope && location != :localScope
lhsExpr[] = LHS(lExpr, true)
scope.locals[symbol(lExpr)] = scope.globals[symbol(lExpr)]
elseif found && location ==:localScope
if rootScope.depth == scope.depth
error("Duplication definition is not allowed")
else
# set new var node
end
else found == false
lvar = scope.globals[symbol(lExpr)]
lvarRef = Ref{WGPUVariable}(lvar)
scope.locals[symbol(lExpr)] = lvarRef
lhsExpr[] = LHS(lExpr)
setMutable!(lhsExpr[], false)
setNew!(lhsExpr[], true)
error("This state shouldn't have been reached in Decl")
end
else
error("This $lhs type Expr is not captured yet")
Expand Down
7 changes: 1 addition & 6 deletions src/codegen/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ end

isMutable(idxExpr::IndexExpr) = isMutable(idxExpr.sym)
setMutable!(idxExpr::IndexExpr, b::Bool) = setMutable!(idxExpr.sym, b)
isNew(idxExpr::IndexExpr) = isNew(idxExpr.sym) # TODO this should always be false

function inferScope!(scope::Scope, jlexpr::IndexExpr)

Expand All @@ -106,7 +105,6 @@ end

isMutable(axsExpr::AccessExpr) = isMutable(axsExpr.sym)
setMutable!(axsExpr::AccessExpr, b::Bool) = setMutable!(axsExpr.sym, b)
isNew(axsExpr::AccessExpr) = isNew(axsExpr.sym) # TODO this should always be false

function accessExpr(scope::Scope, sym::Symbol, field::Symbol)
symExpr = inferExpr(scope, sym)
Expand Down Expand Up @@ -161,9 +159,8 @@ struct DeclExpr <: JLExpr
end

function declExpr(scope, a::Symbol, b::Symbol)
@infiltrate
(found, location, rootScope) = findVar(scope, a)
if found && location == :localScope
if found && location == :globalScope
error("Duplication declaration of variable $a")
end
aExpr = inferExpr(scope, a)
Expand All @@ -186,8 +183,6 @@ end

isMutable(decl::DeclExpr) = isMutable(decl.sym[])
setMutable!(decl::DeclExpr, b::Bool) = setMutable!(decl.sym[], b)
isNew(decl::DeclExpr) = isNew(decl.sym[])
setNew!(decl::DeclExpr, b::Bool) = setNew!(decl.sym[], b)

symbol(decl::DeclExpr) = symbol(decl.sym[])

Expand Down
4 changes: 2 additions & 2 deletions src/codegen/funcBlock.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
struct FuncBlock <: JLBlock
fname::WGPUVariable
fname::Ref{WGPUVariable}
fargs::Vector{DeclExpr}
Targs::Vector{WGPUVariable}
Targs::Vector{Ref{WGPUVariable}}
fbody::Vector{JLExpr}
scope::Union{Nothing, Scope}
end
Expand Down
9 changes: 6 additions & 3 deletions src/codegen/infer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ function inferExpr(scope::Scope, expr::Expr)
elseif @capture(expr, a_ == b_)
return binaryOp(scope, :>=, a, b)
elseif @capture(expr, a_ += b_)
return binaryOp(scope, :+=, a, b)
return binaryOp(scope, :+=, a, b) # TODO this should be assignment expr
elseif @capture(expr, a_ -= b_)
return binaryOp(scope, :-=, a, b)
return binaryOp(scope, :-=, a, b) # TODO this should be assignment expr
elseif @capture(expr, f_(args__))
return callExpr(scope, f, args)
elseif @capture(expr, a_::b_)
Expand Down Expand Up @@ -61,14 +61,17 @@ function inferExpr(scope::Scope, a::Symbol)
(found, location, rootScope) = findVar(scope, a)
var = Ref{WGPUVariable}()
if found == false
var[] = WGPUVariable(a, Any, Generic, nothing, false, false)
var[] = WGPUVariable(a, Any, Generic, nothing, false, true)
scope.globals[a] = var[]
elseif found == true && location == :globalScope
var[] = rootScope.globals[a]
var[].undefined = true
elseif found == true && location == :typeScope
var[] = rootScope.typeVars[a]
var[].undefined = false
else found == true && location == :localScope
var = rootScope.locals[a]
var[].undefined = false
end
return var
end
Expand Down
19 changes: 10 additions & 9 deletions src/codegen/rangeBlock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ struct RangeBlock <: JLBlock
start::Union{WGPUVariable, Scalar}
step::Union{WGPUVariable, Scalar}
stop::Union{WGPUVariable, Scalar}
idx::Union{WGPUVariable}
idx::Union{JLExpr}
block::Vector{JLExpr}
scope::Union{Nothing, Scope}
end

struct RangeExpr <: JLExpr
start::Union{WGPUVariable, Scalar}
step::Union{WGPUVariable, Scalar}
stop::Union{WGPUVariable, Scalar}
start::Union{Ref{WGPUVariable}, Scalar}
step::Union{Ref{WGPUVariable}, Scalar}
stop::Union{Ref{WGPUVariable}, Scalar}
end

function inferExpr(scope::Scope, range::StepRangeLen)
Expand All @@ -24,13 +24,14 @@ function rangeBlock(scope::Scope, idx::Symbol, range::Expr, block::Vector{Any})
startExpr = rangeExpr.start
stopExpr = rangeExpr.stop
stepExpr = rangeExpr.step
idxExpr = inferVariable(childScope, :($idx::UInt32))
scope.globals[idx] = idxExpr[]
scope.locals[idx] = idxExpr
inferScope!(childScope, idxExpr[])
idxExpr = inferExpr(childScope, :($idx::UInt32))
setMutable!(idxExpr, true)
scope.globals[idx] = idxExpr.sym[]
#scope.locals[idx] = idxExpr.sym
#inferScope!(childScope, idxExpr)
exprArray = JLExpr[]
for stmnt in block
push!(exprArray, inferExpr(childScope, stmnt))
end
rangeBlockExpr = RangeBlock(startExpr, stepExpr, stopExpr, idxExpr[], exprArray, childScope)
rangeBlockExpr = RangeBlock(startExpr, stepExpr, stopExpr, idxExpr, exprArray, childScope)
end
24 changes: 24 additions & 0 deletions src/codegen/scope.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,27 @@ function getDataType(scope::Union{Nothing, Scope}, var::Symbol)
return Any
end
end


function Base.isequal(scope::Scope, other::Scope)
length(scope.locals) == length(other.locals) &&
keys(scope.locals) == keys(other.locals) &&
length(scope.globals) == length(other.globals) &&
keys(scope.globals) == keys(other.globals) &&
for (key, value) in scope.locals
if !Base.isequal(other.locals[key][], value[])
return false
end
end
for (key, value) in scope.globals
if Base.isequal(other.globals[key][], value[])
return false
end
end
for (key, value) in scope.typeVars
if Base.isequal(other.typeVars[key][], value[])
return false
end
end
return true
end
Loading

0 comments on commit f4ff1a9

Please sign in to comment.