Skip to content

Commit 3a00dbc

Browse files
committed
additional input type tests
1 parent 48502b9 commit 3a00dbc

File tree

4 files changed

+62
-2
lines changed

4 files changed

+62
-2
lines changed

btctxstore/api.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def validate_address(self, address): # TODO test
3737
try:
3838
deserialize.address(self.testnet, address)
3939
return True
40-
except exceptions.InvalidAddress:
40+
except exceptions.InvalidInput:
4141
return False
4242

4343
def get_address(self, wif):
@@ -48,7 +48,7 @@ def validate_key(self, wif): # TODO test
4848
try:
4949
deserialize.key(self.testnet, wif)
5050
return True
51-
except exceptions.InvalidWif:
51+
except exceptions.InvalidInput:
5252
return False
5353

5454
###############

btctxstore/deserialize.py

+8
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ def unicode_str(string):
3232
return string
3333

3434

35+
def string(s):
36+
if type(s) != type("string"):
37+
raise exceptions.InvalidInput("Must be a string!")
38+
return s
39+
40+
3541
def tx(rawtx):
3642
return Tx.tx_from_hex(rawtx)
3743

@@ -79,6 +85,7 @@ def txid(txhash):
7985

8086

8187
def address(testnet, address):
88+
address = string(address)
8289
netcode = 'XTN' if testnet else 'BTC'
8390
if not validate.is_address_valid(address, allowable_netcodes=[netcode]):
8491
raise exceptions.InvalidAddress(address)
@@ -138,6 +145,7 @@ def secret_exponents(testnet, wifs):
138145

139146

140147
def key(testnet, wif):
148+
wif = string(wif)
141149
netcode = 'XTN' if testnet else 'BTC'
142150
if not validate.is_wif_valid(wif, allowable_netcodes=[netcode]):
143151
raise exceptions.InvalidWif(wif)

tests/api.py

+12
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,9 @@ def test_invalid_network(self):
343343
def test_invalid_data(self):
344344
self.assertFalse(self.testnet_api.validate_address("f483"))
345345

346+
def test_invalid_type(self):
347+
self.assertFalse(self.testnet_api.validate_address(None))
348+
346349

347350
class TestValidateAddressMainnet(unittest.TestCase):
348351

@@ -361,6 +364,9 @@ def test_invalid_network(self):
361364
def test_invalid_data(self):
362365
self.assertFalse(self.mainnet_api.validate_address("f483"))
363366

367+
def test_invalid_type(self):
368+
self.assertFalse(self.mainnet_api.validate_address(None))
369+
364370

365371
class TestValidateKeyTestnet(unittest.TestCase):
366372

@@ -379,6 +385,9 @@ def test_invalid_network(self):
379385
def test_invalid_data(self):
380386
self.assertFalse(self.testnet_api.validate_key("f483"))
381387

388+
def test_invalid_type(self):
389+
self.assertFalse(self.testnet_api.validate_key(None))
390+
382391

383392
class TestValidateKeyMainnet(unittest.TestCase):
384393

@@ -397,6 +406,9 @@ def test_invalid_network(self):
397406
def test_invalid_data(self):
398407
self.assertFalse(self.mainnet_api.validate_key("f483"))
399408

409+
def test_invalid_type(self):
410+
self.assertFalse(self.mainnet_api.validate_key(None))
411+
400412

401413
if __name__ == '__main__':
402414
unittest.main()

tests/deserialize.py

+40
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,26 @@ def callback():
6767
deserialize.key(True, mainnet_wif)
6868
self.assertRaises(exceptions.InvalidWif, callback)
6969

70+
# testnet non string
71+
def callback():
72+
deserialize.key(False, None)
73+
self.assertRaises(exceptions.InvalidInput, callback)
74+
75+
# mainnet non string
76+
def callback():
77+
deserialize.key(True, None)
78+
self.assertRaises(exceptions.InvalidInput, callback)
79+
80+
# testnet garbage string
81+
def callback():
82+
deserialize.key(False, "garbage")
83+
self.assertRaises(exceptions.InvalidWif, callback)
84+
85+
# mainnet garbage string
86+
def callback():
87+
deserialize.key(True, "garbage")
88+
self.assertRaises(exceptions.InvalidWif, callback)
89+
7090

7191
class TestAddress(unittest.TestCase):
7292

@@ -91,3 +111,23 @@ def callback():
91111
def callback():
92112
deserialize.address(True, mainnet_address)
93113
self.assertRaises(exceptions.InvalidAddress, callback)
114+
115+
# non string testnet
116+
def callback():
117+
deserialize.address(False, None)
118+
self.assertRaises(exceptions.InvalidInput, callback)
119+
120+
# non string mainnet
121+
def callback():
122+
deserialize.address(True, None)
123+
self.assertRaises(exceptions.InvalidInput, callback)
124+
125+
# garbage string testnet
126+
def callback():
127+
deserialize.address(False, "garbage")
128+
self.assertRaises(exceptions.InvalidAddress, callback)
129+
130+
# garbage string mainnet
131+
def callback():
132+
deserialize.address(True, "garbage")
133+
self.assertRaises(exceptions.InvalidAddress, callback)

0 commit comments

Comments
 (0)