Skip to content

Commit 105dc44

Browse files
authored
#16 [asm] save context to stack and recusrive lambdas work (#22)
2 parents 793b2a9 + 538394b commit 105dc44

File tree

28 files changed

+585
-185
lines changed

28 files changed

+585
-185
lines changed

sleepy/asmik/emit.py

+54-11
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import override
22

33
import sleepy.tafka.representation as taf
4-
from sleepy.tafka.walker import TafkaWalker
4+
from sleepy.tafka import Context, TafkaWalker, Usages
55

66
from .argument import Immediate, Integer, Unassigned
77
from .argument import PhysicalRegister as PhysReg
@@ -20,6 +20,7 @@
2020
Orb,
2121
Remi,
2222
Slti,
23+
Stor,
2324
Xorb,
2425
mov,
2526
movi,
@@ -42,17 +43,27 @@ def temporary(self) -> VirtReg:
4243
return next(self.sequence)
4344

4445

45-
class AsmikEmitListener(TafkaWalker.Listener):
46+
class AsmikEmitListener(TafkaWalker.ContextedListener):
4647
def __init__(self) -> None:
48+
super().__init__()
49+
4750
self.memory = Memory()
4851
self.registers = VirtualRegisters()
4952

5053
self.resolved: dict[str, int] = {}
5154

55+
self.procedure: taf.Procedure
56+
self.usages: Usages
57+
5258
@override
5359
def enter_procedure(self, procedure: taf.Procedure) -> None:
60+
super().enter_procedure(procedure)
61+
62+
self.usages = Usages.analyzed(procedure)
63+
self.procedure = procedure
64+
5465
addr = self.memory.data_put(IntegerData(self.next_instr_addr))
55-
self.resolved[repr(procedure.const)] = addr
66+
self.resolved[f"${procedure.const.name}"] = addr
5667
for i, param in enumerate(procedure.parameters):
5768
register = self.registers.binded_to(param)
5869
self.emit(mov(register, PhysReg.arg(i + 1)))
@@ -63,6 +74,7 @@ def exit_procedure(self, procedure: taf.Procedure) -> None:
6374

6475
@override
6576
def enter_block(self, block: taf.Block) -> None:
77+
super().enter_block(block)
6678
self.resolved[repr(block.label)] = self.next_instr_addr
6779

6880
@override
@@ -71,7 +83,7 @@ def exit_block(self, block: taf.Block) -> None:
7183

7284
@override
7385
def enter_statement(self, statement: taf.Statement) -> None:
74-
pass
86+
super().enter_statement(statement)
7587

7688
@override
7789
def exit_statement(self, statement: taf.Statement) -> None:
@@ -98,27 +110,48 @@ def on_conditional(self, conditional: taf.Conditional) -> None:
98110
self.emit(movi(else_address, Unassigned(else_label)))
99111
self.emit(Brn(condition, else_address))
100112

113+
def push_context(self, variables: list[taf.Var]) -> None:
114+
def push(register: Reg) -> None:
115+
self.emit(Stor(Reg.sp(), register))
116+
self.emit(Addim(Reg.sp(), Reg.sp(), Integer(8)))
117+
118+
for local in variables:
119+
if self.is_alive(local):
120+
push(self.registers.binded_to(local))
121+
push(Reg.ra())
122+
123+
def pop_context(self, variables: list[taf.Var]) -> None:
124+
def pop(register: Reg) -> None:
125+
self.emit(Addim(Reg.sp(), Reg.sp(), Integer(-8)))
126+
self.emit(Load(register, Reg.sp()))
127+
128+
pop(Reg.ra())
129+
for local in variables[::-1]:
130+
if self.is_alive(local):
131+
pop(self.registers.binded_to(local))
132+
101133
@override
102134
def on_invokation(
103135
self,
104136
target: taf.Var,
105137
source: taf.Invokation,
106138
) -> None:
139+
variables = list(self.procedure.locals)
140+
141+
self.push_context(variables)
142+
107143
for i, arg in enumerate(source.args):
108144
arg_reg = self.registers.binded_to(arg)
109145
self.emit(mov(PhysReg.arg(i + 1), arg_reg))
110146

111-
prev_ra = self.registers.temporary()
112-
self.emit(mov(prev_ra, Reg.ra()))
113-
114147
proc_reg = self.registers.binded_to(source.closure)
115148
self.emit(Addim(Reg.ra(), Reg.ip(), Integer(4)))
116149
self.emit(Brn(Reg.ze(), proc_reg))
117150

118-
res_reg = self.registers.binded_to(target)
119-
self.emit(mov(res_reg, Reg.a1()))
151+
result = self.registers.binded_to(target)
152+
self.emit(mov(result, Reg.a1()))
120153

121-
self.emit(mov(Reg.ra(), prev_ra))
154+
self.pop_context(variables)
122155

123156
@override
124157
def on_load(self, target: taf.Var, source: taf.Load) -> None:
@@ -215,10 +248,20 @@ def addr_of(self, cnst: taf.Const) -> Immediate:
215248
addr = self.memory.data_put(data)
216249
return Integer(addr)
217250
case taf.Signature():
218-
return Unassigned(repr(cnst))
251+
return Unassigned(f"${cnst.name}")
219252
case _:
220253
raise NotImplementedError
221254

255+
def is_alive(self, var: taf.Var) -> bool:
256+
read = self.usages.next_read(var, self.context)
257+
write = self.usages.next_write(var, self.context)
258+
init = self.usages.next_write(var, Context(-1, self.procedure.entry))
259+
260+
assert init is not None # noqa: S101
261+
if write is not None:
262+
return read is not None and read <= write
263+
return read is not None and init < self.position
264+
222265
@property
223266
def next_instr_addr(self) -> int:
224267
return len(self.memory.instr) * 4

sleepy/asmik/unit.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def resolve_addresses(asmik: AsmikEmitListener) -> None:
2626
asmik = AsmikEmitListener()
2727
walker = TafkaWalker(asmik)
2828

29-
walker.explore_block(tafka.main)
29+
walker.explore_procedure(tafka.main)
3030
for proc in tafka.procedures:
3131
walker.explore_procedure(proc)
3232

sleepy/interpreter/asmik.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(self) -> None:
3535
self.registers["ze"] = 0
3636
self.registers["ip"] = 0
3737
self.registers["ra"] = self.STOP
38+
self.registers["sp"] = 10000
3839

3940
self.running = False
4041

@@ -101,9 +102,7 @@ def read(self, reg: Register) -> int:
101102
return self.registers[repr(reg)]
102103

103104
def write(self, reg: Register, value: int) -> None:
104-
match reg:
105-
case "ze":
106-
message = "ze is readonly"
107-
raise SleepyError(message)
108-
case _:
109-
self.registers[repr(reg)] = value
105+
if reg == Register.ze():
106+
message = "ze is readonly"
107+
raise SleepyError(message)
108+
self.registers[repr(reg)] = value

sleepy/main.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66

77
def main() -> None:
88
source = """
9-
(def id (lambda (n int) n))
10-
(def a (id 1))
11-
(def b (id 11))
12-
(def c (id 111))
13-
(if (and (eq a 1)
14-
(and (eq b 11)
15-
(eq c 111))) 1 0)
9+
(def fibb (lambda (n int)
10+
(if (or (eq n 0) (eq n 1))
11+
1
12+
(sum
13+
(self (sum n -1))
14+
(self (sum n -2))))))
15+
(fibb 13)
1616
"""
1717

1818
parser = LarkParser()

sleepy/program/representation.py

+3
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,11 @@ class Intrinsic(Atom):
5454

5555
@dataclass
5656
class Closure(Expression):
57+
from .namespace import Namespace
58+
5759
parameters: list[Parameter]
5860
statements: list[Expression]
61+
namespace: Namespace
5962

6063

6164
@dataclass

sleepy/syntax/s2p.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def visit_lambda(self, tree: LambdaAST) -> Closure:
7878
for parameter in parameters:
7979
self.bindings.bind(parameter.name, parameter)
8080

81-
closure = Closure(parameters, statements=[])
81+
closure = Closure(parameters, statements=[], namespace=self.namespace)
8282
self.bindings.bind(self.namespace.define(Symbol("self")), closure)
8383

8484
closure.statements = [

sleepy/tafka/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,4 @@
3333
)
3434
from .unit import TafkaUnit
3535
from .usage import Usages
36-
from .walker import TafkaWalker
36+
from .walker import Context, TafkaWalker

sleepy/tafka/emit.py

+37-11
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,29 @@ class TafkaEmitVisitor(ProgramVisitor[None]):
1010
def __init__(self, unit: ProgramUnit) -> None:
1111
self.unit = unit
1212

13-
self.main = taf.Block(taf.Label("main"), [])
13+
self.main = taf.Procedure(
14+
name="main",
15+
entry=taf.Block(taf.Label("main"), []),
16+
parameters=[],
17+
value=taf.Unknown(),
18+
)
1419
self.procedures: list[taf.Procedure] = []
1520

1621
self.var_names = map(str, range(10000000))
1722
self.lbl_names = map(str, range(10000000))
1823

1924
self.vars = MetaTable[taf.Var]()
2025

21-
self.current_block = self.main
26+
self.current_procedure = self.main
27+
self.current_block = self.current_procedure.entry
2228
self.last_result = taf.Var("0", taf.Int())
2329

2430
@override
2531
def visit_program(self, tree: program.Program) -> None:
2632
for statement in tree.statements:
2733
self.visit_expression(statement)
2834
self.emit_statement(taf.Return(self.last_result))
35+
self.main.value = self.last_result.kind
2936

3037
@override
3138
def visit_conditional(self, tree: program.Conditional) -> None:
@@ -116,31 +123,46 @@ def visit_application_variable(
116123

117124
@override
118125
def visit_lambda(self, tree: program.Closure) -> None:
126+
current_procedure = self.current_procedure
119127
current_block = self.current_block
120128

121129
label = self.next_lbl()
122130

123-
params = [
131+
procedure = taf.Procedure(
132+
name=label.name,
133+
entry=taf.Block(label, statements=[]),
134+
parameters=[],
135+
value=taf.Unknown(),
136+
)
137+
138+
self.current_procedure = procedure
139+
self.current_block = procedure.entry
140+
141+
procedure.parameters = [
124142
self.next_var(taf.Kind.from_sleepy(param.kind))
125143
for param in tree.parameters
126144
]
127145

128-
for param, var in zip(tree.parameters, params, strict=True):
146+
for param, var in zip(
147+
tree.parameters,
148+
procedure.parameters,
149+
strict=True,
150+
):
129151
self.vars[param.name] = var
130152

131-
body = taf.Block(label, statements=[])
153+
self.emit_intermidiate(
154+
taf.Load(taf.Const(label.name, procedure.signature)),
155+
)
156+
self.vars[tree.namespace.resolved("self")] = self.last_result
132157

133-
self.current_block = body
134158
for statement in tree.statements:
135159
self.visit_expression(statement)
136-
137160
self.emit_statement(taf.Return(self.last_result))
138-
139-
value = self.last_result.kind
161+
procedure.value = self.last_result.kind
140162

141163
self.current_block = current_block
164+
self.current_procedure = current_procedure
142165

143-
procedure = taf.Procedure(label.name, body, params, value)
144166
self.procedures.append(procedure)
145167

146168
self.emit_intermidiate(
@@ -167,13 +189,17 @@ def visit_definition(self, tree: program.Definition) -> None:
167189
def emit_statement(self, statement: taf.Statement) -> None:
168190
if isinstance(statement, taf.Set):
169191
self.last_result = statement.target
192+
if isinstance(statement, taf.Return):
193+
self.current_procedure.value = self.last_result.kind
170194
self.current_block.statements.append(statement)
171195

172196
def emit_intermidiate(self, rvalue: taf.RValue) -> None:
173197
self.emit_statement(taf.Set(self.next_var(rvalue.value), rvalue))
174198

175199
def next_var(self, kind: taf.Kind) -> taf.Var:
176-
return taf.Var(next(self.var_names), kind)
200+
var = taf.Var(next(self.var_names), kind)
201+
self.current_procedure.locals.append(var)
202+
return var
177203

178204
def next_lbl(self) -> taf.Label:
179205
return taf.Label(next(self.lbl_names))

sleepy/tafka/representation/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
Set,
1010
Statement,
1111
)
12-
from .kind import Int, Kind, Signature
12+
from .kind import Int, Kind, Signature, Unknown
1313
from .node import Node
1414
from .rvalue import (
1515
And,

sleepy/tafka/representation/block.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from dataclasses import dataclass
1+
from dataclasses import dataclass, field
22
from typing import cast
33

44
from .kind import Kind, Signature
@@ -60,6 +60,7 @@ class Procedure(Node):
6060
entry: Block
6161
parameters: list[Var]
6262
value: Kind
63+
locals: list[Var] = field(default_factory=list, init=False)
6364

6465
@property
6566
def signature(self) -> Signature:

sleepy/tafka/representation/kind.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,21 @@ def from_sleepy(cls, kind: SleepyKind) -> "Kind":
1818
raise NotImplementedError
1919

2020

21-
@dataclass(repr=False)
21+
@dataclass(repr=False, unsafe_hash=True)
22+
class Unknown(Kind):
23+
@override
24+
def __repr__(self) -> str:
25+
return "?"
26+
27+
28+
@dataclass(repr=False, unsafe_hash=True)
2229
class Int(Kind):
2330
@override
2431
def __repr__(self) -> str:
2532
return "int"
2633

2734

28-
@dataclass(repr=False)
35+
@dataclass(repr=False, unsafe_hash=True)
2936
class Bool(Kind):
3037
@override
3138
def __repr__(self) -> str:
@@ -37,9 +44,10 @@ class Signature(Kind):
3744
params: list[Kind]
3845
value: Kind
3946

47+
@override
48+
def __hash__(self) -> int:
49+
return hash(str(self.params)) + hash(self.value)
50+
4051
@override
4152
def __repr__(self) -> str:
42-
return (
43-
f"({', '.join(repr(_) for _ in self.params)}) "
44-
f"-> {self.value!r}"
45-
)
53+
return f"({', '.join(repr(_) for _ in self.params)}) -> {self.value!r}"

0 commit comments

Comments
 (0)