Commit: reformat python style

zhanglirong1999 committed Oct 23, 2023
1 parent ff39e41 commit a6016d5
Showing 2 changed files with 20 additions and 22 deletions.
30 changes: 14 additions & 16 deletions test/mkldnn/test_fc_add_int8_mkldnn_op.py
@@ -22,7 +22,6 @@

 @OpTestTool.skip_if_not_cpu()
 class TestFCAddINT8OneDNNOp(OpTest):
-
     def setUp(self):
         self.op_type = "fc"
         self._cpu_only = True
@@ -31,8 +30,9 @@ def setUp(self):
         self.generate_data()
         self.set_inputs()

-        y_scales_size = (self.bias_shape
-                         if self.per_channel_quantize_weight else 1)
+        y_scales_size = (
+            self.bias_shape if self.per_channel_quantize_weight else 1
+        )

         self.attrs = {
             'use_mkldnn': True,
@@ -77,16 +77,17 @@ def quantize(self, tensor):
         return scale, quantized

     def generate_data(self):
-        self.x_float = np.random.random(
-            self.input_shape).astype("float32") * 10
+        self.x_float = np.random.random(self.input_shape).astype("float32") * 10
         self.x_scale, self.x = self.quantize(self.x_float)

-        self.y_float = (np.random.random(self.weight_shape).astype("float32") *
-                        10)
+        self.y_float = (
+            np.random.random(self.weight_shape).astype("float32") * 10
+        )
         self.y_scale, self.y = self.quantize(self.y_float)

-        self.residual_float = np.random.random(
-            self.residual_shape).astype("float32") * 10
+        self.residual_float = (
+            np.random.random(self.residual_shape).astype("float32") * 10
+        )
         self.residual_scale, self.residual = self.quantize(self.residual_float)

         flatten_shape = [1, 1]
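
(The body of the `quantize` helper lies above this hunk and is not shown in the diff. As a hedged sketch only, a symmetric per-tensor int8 quantization consistent with the `return scale, quantized` line above might look like the following; the scale convention and rounding are assumptions, not the file's actual code.)

import numpy as np

def quantize(tensor):
    # Assumed symmetric quantization: map the largest absolute
    # value onto the int8 limit, then round and clip into range.
    scale = 127.0 / np.max(np.abs(tensor))
    quantized = np.clip(np.round(tensor * scale), -128, 127).astype(np.int8)
    return scale, quantized
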
@@ -96,11 +97,11 @@ def generate_data(self):
         else:
             flatten_shape[1] *= self.input_shape[i]

-        self.out_float = np.dot(self.x_float.reshape(flatten_shape),
-                                self.y_float)
+        self.out_float = np.dot(
+            self.x_float.reshape(flatten_shape), self.y_float
+        )
         if self.use_bias:
-            self.bias = np.random.random(
-                self.bias_shape).astype("float32") * 10
+            self.bias = np.random.random(self.bias_shape).astype("float32") * 10
             self.out_float += self.bias
         # Don't add residual
         # self.out_float += self.residual_float
@@ -114,7 +115,6 @@ def test_check_output(self):


 class TestFCINT8NoBiasOneDNNOp(TestFCAddINT8OneDNNOp):
-
     def configure(self):
         self.use_bias = False
         self.force_fp32_output = False
@@ -130,7 +130,6 @@ def set_inputs(self):


 class TestFCINT8ForceFP32OutputOneDNNOp(TestFCINT8NoBiasOneDNNOp):
-
     def configure(self):
         self.use_bias = False
         self.force_fp32_output = True
@@ -139,7 +138,6 @@ def configure(self):


 class TestFCINT8ForceFP32OutputPerChannelWeightOneDNNOp(TestFCAddINT8OneDNNOp):
-
     def configure(self):
         self.use_bias = True
         self.force_fp32_output = True
12 changes: 6 additions & 6 deletions test/mkldnn/test_fc_add_mkldnn_op.py
@@ -28,15 +28,13 @@ def fully_connected_naive(input, weights, bias_data, residual_data):


 class MatrixGenerate:
-
     def __init__(self, mb, ic, oc, h, w):
         self.input = np.random.random((mb, ic * h * w)).astype("float32")
         self.weights = np.random.random((ic * h * w, oc)).astype("float32")
         self.residual = np.random.random((mb, oc)).astype("float32")


 class TestFCAddMKLDNNOp(OpTest):
-
     def create_data(self):
         self.matrix = MatrixGenerate(1, 10, 15, 3, 3)
         self.bias = np.random.random(15).astype("float32")
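
(The body of `fully_connected_naive` sits above this hunk, so only its signature is visible here. A minimal sketch consistent with that signature and with the shapes produced by `MatrixGenerate` — an assumption about the reference math, not the file's verbatim code — could be:)

import numpy as np

def fully_connected_naive(input, weights, bias_data, residual_data):
    # Assumed reference FC output: input @ weights, plus broadcast
    # bias, plus the residual tensor (modeling the fused residual add).
    return np.dot(input, weights) + bias_data + residual_data
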
@@ -60,9 +58,12 @@ def setUp(self):
         # self.attrs = {'use_mkldnn': self.use_mkldnn, 'fuse_residual_connection' : True}

         self.outputs = {
-            'Out':
-            fully_connected_naive(self.matrix.input, self.matrix.weights,
-                                  self.bias, self.matrix.residual)
+            'Out': fully_connected_naive(
+                self.matrix.input,
+                self.matrix.weights,
+                self.bias,
+                self.matrix.residual,
+            )
         }

     def test_check_output(self):
@@ -76,7 +77,6 @@ def test_check_grad_no_weight(self):


 class TestFCAddMKLDNNOp1(TestFCAddMKLDNNOp):
-
     def create_data(self):
         self.matrix = MatrixGenerate(2, 15, 48, 2, 2)
         self.bias = np.random.random(48).astype("float32")
