Skip to content

Commit ee2c50a

Browse files
authoredFeb 19, 2025
Add types.Unalias to types assertions and types switches to get an underlying type instead of types.Alias (#33868)
1 parent 7f08bba commit ee2c50a

File tree

1 file changed

+28
-13
lines changed

1 file changed

+28
-13
lines changed
 

‎sdks/go/pkg/beam/util/starcgenx/starcgenx.go

+28-13
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
// Package starcgenx is a Static Analysis Type Assertion shim and Registration Code Generator
1717
// which provides an extractor to extract types from a package, in order to generate
18-
// approprate shimsr a package so code can be generated for it.
18+
// appropriate shims for a package so code can be generated for it.
1919
//
2020
// It's written for use by the starcgen tool, but separate to permit
2121
// alternative "go/importer" Importers for accessing types from imported packages.
@@ -336,6 +336,7 @@ func (e *Extractor) isRequired(ident string, obj types.Object, idsRequired, idsF
336336
// or it's receiver type identifier needs to be in the filtered identifiers.
337337
if idsRequired[ident] {
338338
idsFound[ident] = true
339+
e.Printf("isRequired found: %s\n", ident)
339340
return true
340341
}
341342
// Check if this is a function.
@@ -347,10 +348,10 @@ func (e *Extractor) isRequired(ident string, obj types.Object, idsRequired, idsF
347348
if recv := sig.Recv(); recv != nil && graph.IsLifecycleMethod(ident) {
348349
// We don't want to care about pointers, so dereference to value type.
349350
t := recv.Type()
350-
p, ok := t.(*types.Pointer)
351+
p, ok := types.Unalias(t).(*types.Pointer)
351352
for ok {
352353
t = p.Elem()
353-
p, ok = t.(*types.Pointer)
354+
p, ok = types.Unalias(t).(*types.Pointer)
354355
}
355356
ts := types.TypeString(t, e.qualifier)
356357
e.Printf("recv %v has %v, ts: %s %s--- ", recv, sig, ts, ident)
@@ -384,14 +385,16 @@ func (e *Extractor) fromObj(fset *token.FileSet, id *ast.Ident, obj types.Object
384385
ident = obj.Name()
385386
}
386387
if !e.isRequired(ident, obj, idsRequired, idsFound) {
388+
e.Printf("%s: %q with package %q is not required \n",
389+
fset.Position(id.Pos()), id.Name, pkg.Name())
387390
return
388391
}
389392

390393
switch ot := obj.(type) {
391394
case *types.Var:
392395
// Vars are tricky since they could be anything, and anywhere (package scope, parameters, etc)
393396
// eg. Flags, or Field Tags, among others.
394-
// I'm increasingly convinced that we should simply igonore vars.
397+
// I'm increasingly convinced that we should simply ignore vars.
395398
// Do nothing for vars.
396399
case *types.Func:
397400
sig := obj.Type().(*types.Signature)
@@ -405,10 +408,10 @@ func (e *Extractor) fromObj(fset *token.FileSet, id *ast.Ident, obj types.Object
405408
}
406409
// This must be a structural DoFn! We should generate a closure wrapper for it.
407410
t := recv.Type()
408-
p, ok := t.(*types.Pointer)
411+
p, ok := types.Unalias(t).(*types.Pointer)
409412
for ok {
410413
t = p.Elem()
411-
p, ok = t.(*types.Pointer)
414+
p, ok = types.Unalias(t).(*types.Pointer)
412415
}
413416
ts := types.TypeString(t, e.qualifier)
414417
mthdMap := e.wraps[ts]
@@ -453,6 +456,10 @@ func (e *Extractor) extractType(ot *types.TypeName) {
453456
// A single level is safe since the code we're analysing imports it,
454457
// so we can assume the generated code can access it too.
455458
if ot.IsAlias() {
459+
if t, ok := ot.Type().(*types.Alias); ok {
460+
ot = t.Obj()
461+
name = types.TypeString(t, e.qualifier)
462+
}
456463
if t, ok := ot.Type().(*types.Named); ok {
457464
ot = t.Obj()
458465
name = types.TypeString(t, e.qualifier)
@@ -461,7 +468,7 @@ func (e *Extractor) extractType(ot *types.TypeName) {
461468
// Only register non-universe types (eg. avoid `error` and similar)
462469
if pkg := ot.Pkg(); pkg != nil {
463470
path := pkg.Path()
464-
e.imports[pkg.Path()] = struct{}{}
471+
e.imports[path] = struct{}{}
465472

466473
// Do not add universal types to be registered.
467474
if path == shimx.TypexImport {
@@ -484,17 +491,17 @@ func (e *Extractor) extractFromContainer(t types.Type) types.Type {
484491
// Container types need to be iteratively unwrapped until we're at the base type,
485492
// so we can get the import if necessary.
486493
for {
487-
if s, ok := t.(*types.Slice); ok {
494+
if s, ok := types.Unalias(t).(*types.Slice); ok {
488495
t = s.Elem()
489496
continue
490497
}
491498

492-
if p, ok := t.(*types.Pointer); ok {
499+
if p, ok := types.Unalias(t).(*types.Pointer); ok {
493500
t = p.Elem()
494501
continue
495502
}
496503

497-
if a, ok := t.(*types.Array); ok {
504+
if a, ok := types.Unalias(t).(*types.Array); ok {
498505
t = a.Elem()
499506
continue
500507
}
@@ -510,9 +517,17 @@ func (e *Extractor) extractFromTuple(tuple *types.Tuple) {
510517
t := e.extractFromContainer(s.Type())
511518

512519
// Here's where we ensure we register new imports.
520+
if at, ok := t.(*types.Alias); ok {
521+
if pkg := at.Obj().Pkg(); pkg != nil {
522+
e.imports[pkg.Path()] = struct{}{}
523+
}
524+
}
513525
if t, ok := t.(*types.Named); ok {
514526
if pkg := t.Obj().Pkg(); pkg != nil {
527+
e.Printf("extractType: adding import path %q for %v\n", pkg.Path(), t)
515528
e.imports[pkg.Path()] = struct{}{}
529+
} else {
530+
e.Printf("extractType: %v has no package to import\n", t)
516531
}
517532
e.extractType(t.Obj())
518533
}
@@ -683,7 +698,7 @@ func (e *Extractor) makeEmitter(sig *types.Signature) (shimx.Emitter, bool) {
683698

684699
// makeInput checks if the given signature is an iterator or not, and if so,
685700
// returns a shimx.Input struct for the signature for use by the code
686-
// generator. The canonical check for an iterater signature is in the
701+
// generator. The canonical check for an iterator signature is in the
687702
// funcx.UnfoldIter function which uses the reflect library,
688703
// and this logic is replicated here.
689704
func (e *Extractor) makeInput(sig *types.Signature) (shimx.Input, bool) {
@@ -692,13 +707,13 @@ func (e *Extractor) makeInput(sig *types.Signature) (shimx.Input, bool) {
692707
return shimx.Input{}, false
693708
}
694709
// Iterators must return a bool.
695-
if b, ok := r.At(0).Type().(*types.Basic); !ok || b.Kind() != types.Bool {
710+
if b, ok := types.Unalias(r.At(0).Type()).(*types.Basic); !ok || b.Kind() != types.Bool {
696711
return shimx.Input{}, false
697712
}
698713
p := sig.Params()
699714
for i := 0; i < p.Len(); i++ {
700715
// All params for iterators must be pointers.
701-
if _, ok := p.At(i).Type().(*types.Pointer); !ok {
716+
if _, ok := types.Unalias(p.At(i).Type()).(*types.Pointer); !ok {
702717
return shimx.Input{}, false
703718
}
704719
}

0 commit comments

Comments
 (0)