17
17
18
18
ClassifiedImport = namedtuple (
19
19
'ClassifiedImport' ,
20
- ['type' , 'is_from' , 'modules' , 'names' , 'lineno' , 'level' , 'package' ],
20
+ ['type' , 'is_from' , 'modules' , 'names' , 'lineno' , 'level' , 'package' ,
21
+ 'type_checking' ],
21
22
)
22
23
NewLine = namedtuple ('NewLine' , ['lineno' ])
23
24
@@ -70,8 +71,13 @@ def __init__(self, application_import_names, application_package_names):
70
71
self .application_import_names = frozenset (application_import_names )
71
72
self .application_package_names = frozenset (application_package_names )
72
73
74
+ def generic_visit (self , node ):
75
+ for child in ast .iter_child_nodes (node ):
76
+ child .parent = node
77
+ return super ().generic_visit (node )
78
+
73
79
def visit_Import (self , node ): # noqa: N802
74
- if node .col_offset == 0 :
80
+ if node .col_offset == 0 or self . _type_checking_import ( node ) :
75
81
modules = [alias .name for alias in node .names ]
76
82
types_ = {self ._classify_type (module ) for module in modules }
77
83
if len (types_ ) == 1 :
@@ -81,11 +87,12 @@ def visit_Import(self, node): # noqa: N802
81
87
classified_import = ClassifiedImport (
82
88
type_ , False , modules , [], node .lineno , 0 ,
83
89
root_package_name (modules [0 ]),
90
+ self ._type_checking_import (node ),
84
91
)
85
92
self .imports .append (classified_import )
86
93
87
94
def visit_ImportFrom (self , node ): # noqa: N802
88
- if node .col_offset == 0 :
95
+ if node .col_offset == 0 or self . _type_checking_import ( node ) :
89
96
module = node .module or ''
90
97
if node .level > 0 :
91
98
type_ = ImportType .APPLICATION_RELATIVE
@@ -96,9 +103,16 @@ def visit_ImportFrom(self, node): # noqa: N802
96
103
type_ , True , [module ], names ,
97
104
node .lineno , node .level ,
98
105
root_package_name (module ),
106
+ self ._type_checking_import (node ),
99
107
)
100
108
self .imports .append (classified_import )
101
109
110
+ def _type_checking_import (self , node ):
111
+ return (
112
+ isinstance (node .parent , ast .If )
113
+ and node .parent .test .id == "TYPE_CHECKING"
114
+ )
115
+
102
116
def _classify_type (self , module ):
103
117
package_names = get_package_names (module )
104
118
0 commit comments