USYD - DATA 5207

Lecture 2

presenting data to non-experts (visualisation)

audiences with less technical knowledge

making data engaging

conveying the pattern

Data Graphics

three considerations:

what information you want to communicate

who the target audience is

why the design choices are relevant to them

Lecture 3

confounding factors: e.g., earnings appear to vary by height, but the effect may actually be driven by gender

select the topic this weekend

RMD template

Lecture 4

a higher R-squared means the model explains more variance, i.e., a better fit

but note that adding more variables always pushes in-sample R-squared up, so "maximise the variables in the model" is a trap (see the sketch below)
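A minimal sketch (NumPy only, my own illustration rather than course material) of why that is a trap: in-sample R-squared never goes down when you add predictors, even pure-noise ones.

import numpy as np

rng = np.random.default_rng(0)
n = 100
x = rng.normal(size=(n, 1))
y = 1.5 * x[:, 0] + rng.normal(size=n)

def r_squared(X, y):
    X1 = np.column_stack([np.ones(len(X)), X])     # add an intercept column
    beta, *_ = np.linalg.lstsq(X1, y, rcond=None)  # OLS fit
    resid = y - X1 @ beta
    return 1 - resid @ resid / ((y - y.mean()) @ (y - y.mean()))

print(r_squared(x, y))                             # real predictor only
noise = rng.normal(size=(n, 5))                    # five junk predictors
print(r_squared(np.column_stack([x, noise]), y))   # R-squared creeps up anyway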

don't rely only on the technical fit to the data; add theoretical grounding to increase the model's usefulness in practice

observational data alone cannot support causal inference (confounding factors may be present)

model error

random error (precision limited by sample size)

systematic error (flaws in the research design; non-sampling error)

the difference between observed and actual values

response, instrument, interviewer

sample design error

selection, frame

explore the dataset with graphs

variables

model choice

correlation plot of all variables

variable selection methods: stepwise, lasso, etc. (a minimal sketch below)
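A minimal sketch of lasso as a variable selection method, assuming scikit-learn and made-up data (the course itself works in R, so this Python version is just an illustration):

import numpy as np
from sklearn.linear_model import LassoCV
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 10))                         # 10 candidate predictors
y = 2 * X[:, 0] - 3 * X[:, 3] + rng.normal(size=200)   # only two actually matter

# standardise so the L1 penalty treats all predictors comparably
X_std = StandardScaler().fit_transform(X)

# LassoCV picks the penalty strength by cross-validation;
# coefficients shrunk exactly to zero are "deselected"
model = LassoCV(cv=5).fit(X_std, y)
print(np.flatnonzero(model.coef_))  # indices of the selected predictors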

Lecture 5

a limitation of linear regression is the assumption that the relationship is linear

logits?

ordinal logistic regression (for ordered responses: agree, strongly agree, etc.)

For the lab 5 material, the last figure could be redrawn so that, for each year, the importance of each independent variable is plotted in a single panel.

Lecture 6

fuzzyjoin is an R package whose joins work like SQL joins but allow approximate (fuzzy) matching

Research Plan Format

Format

hide the R chunks; the template already contains the code to hide them

key features should be identified

why linear regression is used

Literature

theory from Literature

the hypothesis is what will be tested; it should follow from the previous section

the literature motivates and communicates the hypothesis you propose

it informs the things you need to do

it underpins the thing you want to explain

the literature may contain errors; your analysis can falsify its ideas

Data

API (data source); missing values

operationalisation

Limitation

can be deleted

Lecture 7 (quiz week)

data can come from:

consumer data

social media

A/B testing (to decide which version performs better among several)

for instance, buttons of different colours and sizes may be shown to different users, and their click counts decide the better version (a minimal sketch below)
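A hedged sketch of how such a comparison could be decided, with invented numbers (a plain two-proportion z-test; not taken from the lecture):

from math import sqrt
from scipy.stats import norm

clicks_a, users_a = 120, 1000  # variant A, e.g. the blue button
clicks_b, users_b = 150, 1000  # variant B, e.g. the red button

p_a, p_b = clicks_a / users_a, clicks_b / users_b
p_pool = (clicks_a + clicks_b) / (users_a + users_b)

# two-proportion z-test under the pooled null hypothesis p_a == p_b
se = sqrt(p_pool * (1 - p_pool) * (1 / users_a + 1 / users_b))
z = (p_b - p_a) / se
p_value = 2 * (1 - norm.cdf(abs(z)))
print(z, p_value)  # a small p-value supports shipping the better-clicking version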

census

individual-level data, including surveys

web scraping

Lecture 8

survey

systematic error (not random)

younger people may be more likely to respond to phone surveys (nonresponse bias)

a census, unlike a survey, does not sample: it covers the entire resident population of the country

random error comes from sampling

Assignment-1

  1. Economic: Q288, 50 (income)
  2. Occupation: Q281, 282, 283
  3. Education: Q275
    1. https://d1wqtxts1xzle7.cloudfront.net/49101438/18.01.053.20160304-libre.pdf?1474797472=&response-content-disposition=inline%3B+filename%3DEffect_of_Education_on_Quality_of_Life_a.pdf&Expires=1713509378&Signature=bTvJ0cklHa83ixDEhUTW02gYB4KW0iex7Mx6etlJqBNha-f0l-gvirWcVjlpbtaXdn5SsFoSsWtjeay-18z5De6i3e2wRtZvtx5cuzyJe2RLJHKYPPXrkiEORhb9c35JK-WjFa7T8c8OIQj5RxD11Gj3W7wCsC3jJwVOewTDYwkVBKXC1-7BjpWcbOSrkZnazJwulzVzLIERo0l6iO51LIqFi6wY8TSiTTdFGhiHctf9bu2Y7IapgVAwDLKXbpYTdXd3c4nVMPqQryYQ5iOjKVEmcCdMQwn0HUGe837Dn38-7ttCIbNASUOgpjEGQEjmNlznMsOW9jG~X9VHjw__&Key-Pair-Id=APKAJLOHF5GGSLRBV4ZA
    2. https://www.sciencedirect.com/science/article/pii/S2214804314001153?via%3Dihub
  4. Societal Wellbeing: Q47 (health)
  5. Security: Q131-138, 52
  6. Social capital, trust: Q57-61
    1. https://link.springer.com/article/10.1007/s00148-007-0146-7 (neighbourhood only)

PCA to combine multiple variables into a single feature (sketch below)
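A minimal sketch, assuming scikit-learn and fake Likert-style data (the variable names are mine, not from the assignment): collapse several related survey items, e.g. the trust questions Q57-Q61, into one PCA component.

import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(1)
items = rng.integers(1, 5, size=(300, 5)).astype(float)  # 5 Likert-style items

# standardise, then keep the first principal component as the combined feature
scores = PCA(n_components=1).fit_transform(StandardScaler().fit_transform(items))
print(scores.shape)  # (300, 1): one combined score per respondent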

Lecture 10

kable library in R (for tables)

Lecture 11 - Causality

Lecture 12 - Journalism

datasplash platform

Final Project

plots and tables can be included in the report

use the kable function to present regression results as a table

USYD - INFO 5992

Lecture 1

Three advantages and three disadvantages of 5G wireless connections over 4G:

  1. More capacity for simultaneous device connections

  2. Higher transmission speed than 4G

  3. Lower latency

  4. The price of a 5G IoT module is almost double that of a 4G one, an economic challenge despite the benefits in other areas.

  5. 5G coverage is weaker than 4G's, and establishing extensive coverage will still take a long time.

  6. Compatibility may become a problem in the future, since out-of-date devices may not support the new standard (5G).

Q1: Yes.

  1. The benefits of 5G (speed, network capacity, lower latency, etc.) are highly suitable for many organisations, especially IT-based ones.
  2. The development of 5G is unstoppable; the cost, coverage, and other weaknesses will be solved in the near future. Therefore, end users, organisations, and governments will embrace the network evolution (i.e. 5G).

Q2: The price will be lower.

  1. The cost of the 5G network will fall as the infrastructure and technology mature, and this will be reflected in the market price.
  2. The number of 5G users will grow gradually, so each base station's cost is spread across more users, further reducing the cost of using the 5G network.

Q3: Yes.

  1. 5G removes many physical limitations, for instance latency. In clinical practice, one of the biggest limitations of remote surgery is network delay; 5G's low latency reduces this limitation and enlarges the feasibility.
  2. Also, 5G's high speed makes large-capacity meetings possible.

Q4: No.

In my view, 5G is faster than 4G but does not fundamentally change how the network works; it is a development of the network rather than a revolution.

Lecture 2

  1. Diffusion of innovation
    1. Innovation development process
  2. Technology adoption lifecycle model
  3. Dominant design

Lecture 3

  1. Disruptive innovation
  2. Innovator’s dilemma

Lecture 4

API business model

API as product

API promoting means making the main business more popular

API enhancing means making the functionality better

an API may promote without enhancing, but it cannot enhance without also promoting

Lecture 5

Types of Crowdsourcing P14


  1. putting information and data onto the platform
  2. comparing with others' solutions (e.g., 10% better)
  3. creative tasks
  4. tasks requiring a basic level of human intelligence

Lecture 6

user innovation: innovations that come from users or customers (including companies, B2B), driven by unfulfilled requirements

Lecture 7

customer pivoting

solving a problem for a certain customer segment, then solving another problem for the same people

business pivoting

solving different problems

Lecture 8

value proposition

Pytorch Tutorial 5

The reason to use a neural network is that the core of logistic regression is still linear; to move beyond a purely linear relationship, a neural network applies an activation function, for instance ReLU.

In this case we use ReLU as the activation function to predict the images, and the accuracy turns out to be far better than logistic regression's.

from os import path, mkdir
from random import randint

import torch
import numpy as np
import torchvision
from matplotlib import pyplot as plt
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.dataloader import DataLoader
import torch.nn.functional as F
import torch.nn as nn

dataset = MNIST(root="./data", download=True, transform=ToTensor())
test_dataset = MNIST(root='./data', train=False, transform=ToTensor())

def split_indices(n, rate):
    # number of validation samples
    n_val = int(n * rate)
    # shuffled indices 0..n-1, with no repeats
    idxs = np.random.permutation(n)
    # return the (n_val..last) indices and the (first n_val) indices,
    # i.e. training indices and validation indices
    return idxs[n_val:], idxs[:n_val]

train_indices, val_indices = split_indices(len(dataset), 0.2)

batch_size = 100
train_sampler = SubsetRandomSampler(train_indices)
train_loader = DataLoader(dataset,
                          batch_size,
                          sampler=train_sampler)

val_sampler = SubsetRandomSampler(val_indices)
val_loader = DataLoader(dataset,
                        batch_size,
                        sampler=val_sampler)

input_size = 28 * 28
num_classes = 10

class MnistModel(nn.Module):

    def __init__(self, in_size, hidden_size, out_size):
        super().__init__()
        self.linear1 = nn.Linear(in_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, out_size)

    def forward(self, xb):
        # flatten
        xb = xb.view(xb.size(0), -1)
        # xb = xb.reshape(xb.size(0), -1)
        return self.linear2(F.relu(self.linear1(xb)))

# for t in model.parameters():
#     print(t.shape)

# for img, labels in train_loader:
#     outputs = model(img)
#     loss = F.cross_entropy(outputs, labels)
#     break

def get_device():
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')

def to_device(data, device):
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

# for img, label in train_loader:
#     print(img.shape)
#     img = to_device(img, device)
#     print(img.device)
#     break

class DeviceDataLoader():
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        # lazy loading: instead of moving the whole dataset to the device,
        # move each batch as it is requested
        for b in self.dl:
            yield to_device(b, self.device)

    def __len__(self):
        return len(self.dl)

# use DeviceDataLoader as a wrapper
train_dl = DeviceDataLoader(train_loader, get_device())
valid_dl = DeviceDataLoader(val_loader, get_device())

def loss_batch(model, loss_func, xb, yb, opt=None, metric=None):
    preds = model(xb)

    loss = loss_func(preds, yb)

    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    metric_result = None
    if metric is not None:
        metric_result = metric(preds, yb)

    return loss.item(), len(xb), metric_result

def evaluate(model, loss_func, valid_dl, metric=None):
    with torch.no_grad():
        results = [loss_batch(model, loss_func, xb, yb, metric=metric)
                   for xb, yb in valid_dl]

        # separate the lists
        loss, nums, metric = zip(*results)
        total = np.sum(nums)
        avg_loss = np.sum(np.multiply(loss, nums)) / total
        avg_metric = None
        if metric is not None:
            avg_metric = np.sum(np.multiply(metric, nums)) / total
    return avg_loss, total, avg_metric

def fit(epochs, lr, model, loss_func, train_dl, valid_dl, opt_fn=None, metric=None):
    if opt_fn is None:
        opt_fn = torch.optim.SGD
    opt = opt_fn(model.parameters(), lr=lr)
    loss_history = []
    metric_history = []

    for epoch in range(epochs):
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)
        result = evaluate(model, loss_func, valid_dl, metric)
        val_loss, total, val_metric = result

        loss_history.append(val_loss)
        metric_history.append(val_metric)

        if metric is not None:
            print(f'Epoch [{epoch + 1}/{epochs}], Loss: {val_loss:.4f}, Metric: {val_metric:.4f}')
        else:
            print(f'Epoch [{epoch + 1}/{epochs}], Loss: {val_loss:.4f}')

    return loss_history, metric_history

def accuracy(output, label):
    _, preds = torch.max(output, dim=1)
    return torch.sum(label == preds).item() / len(preds)

model = MnistModel(input_size, 32, num_classes)
to_device(model, get_device())

if path.exists('./tutorial5/mnist-logistic.pth'):
    model.load_state_dict(torch.load('./tutorial5/mnist-logistic.pth'))

else:
    loss_history, metric_history = fit(5, 0.5, model, F.cross_entropy,
                                       train_dl,
                                       valid_dl,
                                       opt_fn=torch.optim.SGD,
                                       metric=accuracy)
    # save the weights and biases for this model
    # create the directory first if needed
    if not path.exists('./tutorial5'):
        mkdir('./tutorial5')
    torch.save(model.state_dict(), './tutorial5/mnist-logistic.pth')

def prediction_img(img, model):
    xb = img.unsqueeze(0)
    yb = model(xb)
    _, preds = torch.max(yb, dim=1)
    return preds[0].item()

# note: on a GPU machine, img would also need moving to the device
for i in range(10):
    img, label = test_dataset[randint(0, len(test_dataset) - 1)]
    img_np = np.array(img)
    plt.imshow(img_np.squeeze(), cmap='gray')
    plt.show()
    print(prediction_img(img, model))

Pytorch Tutorial 3

simple linear regression with built-in tools in PyTorch

  1. generate predictions
  2. calculate the loss
  3. compute gradients of w and b
  4. adjust w and b
  5. reset gradients to zero

these five steps correspond to the loop body in the fit function below

import numpy as np
import torch.nn as nn
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.nn.functional as F

# temp, rainfall, humidity
# inputs = torch.tensor(np.random.uniform(0, 120, size=(15, 3)))
# the inputs and targets need an explicit dtype; otherwise, when torch
# generates predictions, the dtypes will not match
inputs = torch.tensor(np.array(
    [[109.4144, 11.2775, 32.4521], [2.0002, 47.0248, 49.9469], [27.1528, 57.8907, 91.2076],
     [44.8227, 71.6239, 64.0752], [66.0968, 92.5966, 94.0775], [59.6257, 76.9701, 92.1656],
     [8.1551, 1.7426, 10.5297], [112.6036, 47.2793, 95.4221], [3.2212, 61.8274, 115.9187],
     [35.0351, 110.6133, 66.6992], [8.8387, 21.8008, 50.0480], [68.7698, 59.9815, 12.0230],
     [111.3881, 90.3050, 62.1327], [101.7462, 115.7447, 33.4925], [27.7659, 54.5803, 105.3599]], dtype='float32'))

# apples, oranges
# targets = torch.tensor(np.random.uniform(0, 50, size=(15, 2)))
targets = torch.tensor(np.array(
    [[28.1090, 45.0061], [29.0839, 6.4205], [35.2633, 44.1196],
     [29.5371, 6.8457], [7.4298, 36.1434], [6.6296, 47.1809],
     [49.9750, 49.9321], [34.1796, 16.6732], [46.8875, 7.6084],
     [23.0442, 42.2229], [29.7401, 13.4199], [3.0854, 21.4550],
     [47.6801, 49.1518], [18.7320, 18.4418], [34.2725, 25.8721]], dtype='float32'))
# print(inputs)
# print(targets)

# TensorDataset pairs each input row with its target row
train_ds = TensorDataset(inputs, targets)

batch_size = 5
train_dl = DataLoader(train_ds, batch_size, shuffle=True)

# Each batch has size 5, and the data are shuffled;
# input/target pairs stay together, only the order is shuffled
# for xb, yb in train_dl:
#     print("batch:")
#     print(xb)
#     print(yb)

# specify the number of input and output features
model = nn.Linear(3, 2)
# the weights and bias are initialised automatically, with requires_grad set to True
# print(model.weight)
# print(model.bias)
# print(list(model.parameters()))

# preds = model(inputs)
# print(preds)

loss_fn = F.mse_loss
loss = loss_fn(model(inputs), targets)
# print(loss)

opt = torch.optim.SGD(model.parameters(), lr=1e-5)

# 1 generate predictions
# 2 calculate the loss
# 3 compute gradients of w and b
# 4 adjust w and b
# 5 reset gradients to zero
# these five steps correspond to the loop body in the next function

def fit(num_epochs, model, loss_fn, opt):
    # training iterations
    for epoch in range(num_epochs):
        # batches within each iteration
        for xb, yb in train_dl:
            pred = model(xb)
            loss = loss_fn(pred, yb)
            loss.backward()
            opt.step()
            opt.zero_grad()
        if (epoch + 1) % 10 == 0:
            print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, loss.item()))

fit(100, model, loss_fn, opt)

Pytorch Tutorial 4

  1. load the dataset
    1. transform the data into tensors
  2. split the dataset into training, testing, and validation sets
    1. define an index-shuffling function (the dataset is ordered; without shuffling, a split might contain only one label)
    2. create samplers and loaders
  3. customise the MnistModel class
  4. define loss_batch
    1. calculates the loss for the current batch
  5. define evaluate
    1. calculates the average loss over batches
  6. define accuracy
    1. also called a metric; it reports the accuracy
  7. create the fit function
    1. epoch loop
      1. training loop
        1. loss_batch, for training
      2. evaluate the result
      3. print the result
  8. call fit
from os import path
from random import randint

import torch
import torchvision
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn as nn
import torch.nn.functional as F

# transforms is used to turn the MNIST dataset into tensors that torch can work with
import torchvision.transforms as transforms

# here the dataset is in its original format and cannot be understood by torch
datasets = MNIST(root='./data', download=True)
# print(len(datasets))

test_dataset = MNIST(root='./data', train=False, transform=transforms.ToTensor())
# print(len(test_dataset))

# img, label = datasets[0]
# plt.imshow(img, cmap='gray')
# plt.show()

# print(label)

# here the dataset is already transformed into tensors
dataset = MNIST(root='./data', download=True, transform=transforms.ToTensor())

# the shape here is 1,28,28: channels, height, width
# img_tensor, label = dataset[0]
# print(img_tensor.shape, label)

# print(img_tensor[:, 10:15, 10:15])
# print(torch.max(img_tensor), torch.min(img_tensor))
# plt.imshow(img_tensor[0, 10:15, 10:15], cmap='gray')
# plt.show()

def split_indices(n, rate):
    # number of validation samples
    n_val = int(n * rate)
    # shuffled indices 0..n-1, with no repeats
    idxs = np.random.permutation(n)
    # return the (n_val..last) indices and the (first n_val) indices,
    # i.e. training indices and validation indices
    return idxs[n_val:], idxs[:n_val]

train_indices, val_indices = split_indices(len(dataset), 0.2)
# print(len(train_indices), len(val_indices))

# the sampler randomly selects batch_size indices from the list;
# this lowers training time and computation, and lets multiple epochs be
# used to train the model. Otherwise training would process the whole
# dataset at once, occupying too much memory and putting too much pressure
# on computational resources. This way, training proceeds in smaller chunks.
batch_size = 100
train_sampler = SubsetRandomSampler(train_indices)
train_loader = DataLoader(dataset,
                          batch_size,
                          sampler=train_sampler)

val_sampler = SubsetRandomSampler(val_indices)
val_loader = DataLoader(dataset,
                        batch_size,
                        sampler=val_sampler)

input_size = 28 * 28
num_classes = 10

# model = nn.Linear(input_size, num_classes)

# print(model.weight.shape)
# print(model.bias.shape)
#
# print(model.weight)
# print(model.bias)

# for img, label in train_loader:
#     print(img.shape)
#     print(label)
#     # there is an error here: the image shape is 1*28*28, but the expected
#     # input size is 784, so a customised model is needed
#     print(model(img))
#     break

class MnistModel(nn.Module):
    def __init__(self):
        super().__init__()
        # define the input and output sizes for the linear layer
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        # reshape with -1 avoids hard-coding; the first dimension is inferred
        xb = xb.reshape(-1, input_size)
        # pass the batch to the linear layer
        out = self.linear(xb)
        return out

model = MnistModel()

# the weights and bias live in the linear layer (model.linear.weight),
# not on the model itself (model.weight)
# print(model.linear.weight.shape)
# print(model.linear.bias.shape)
#
# print(model.linear.weight)
# print(model.linear.bias)

def accuracy(l1, l2):
    return torch.sum(l1 == l2).item() / len(l2)

Log plot presentation


# for img, label in train_loader:
#     # the img batch passed to the model has shape 100,1,28,28
#     # the output shape is 100,10,
#     # which is what we expect (one score per digit 0-9)
#     # softmax can be applied to show the probability of each digit:
#     # probability = e^y_i / sum(e^y_i)
#     outputs = model(img)
#     # the second parameter is the index of the dimension to apply softmax over,
#     # so 0 means down the columns and 1 along the rows of a 2D matrix
#     probs = F.softmax(outputs, 1)
#     print(probs.shape)
#     # probs still has shape 100,10, but each value is a probability (0-1) and each row sums to 1
#     print(outputs.shape)
#     print(outputs[0])
#     max_probs, predicted_labels = torch.max(probs, 1)
#     print(accuracy(predicted_labels, label))

# now we need to define the loss function;
# cross entropy is the most suitable for logistic regression
# i.e.
# the true label 9 is represented by the vector [0,0,0,0,0,0,0,0,0,1]
# a predicted vector might be [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9], for instance,
# and the cross entropy is -ln(y*y_pred), i.e. -ln(1*0.9) = 0.10, which is low

# but when the prediction is poor:
# the true label 1 is represented by the vector [0,1,0,0,0,0,0,0,0,0]
# the predicted vector is [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9], for instance,
# and the cross entropy is -ln(y*y_pred), i.e. -ln(1*0.2) = 1.6, which is high

# cross entropy only considers the correct label and ignores the others,
# because their vector entries are 0

# so a low probability for the correct digit gives a high cross entropy (loss), and vice versa

# define the loss for the current batch
# loss = F.cross_entropy(outputs, label)

# loss = -ln(p_correct), so the probability of the correct label is p_correct = e^(-loss)
# learn_rate = 0.001
# optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)
# optimizer.step()
# break

def loss_batch(model, loss_func, xb, yb, opt=None, metric=None):
    preds = model(xb)
    loss = loss_func(preds, yb)

    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    # metric is used for model evaluation
    metric_result = None
    if metric is not None:
        metric_result = metric(preds, yb)

    return loss.item(), len(xb), metric_result

def evaluate(model, loss_func, valid_dl, metric=None):
    with torch.no_grad():
        results = [loss_batch(model, loss_func, xb, yb, metric=metric)
                   for xb, yb in valid_dl]

        # separate the lists
        loss, nums, metric = zip(*results)
        total = np.sum(nums)
        avg_loss = np.sum(np.multiply(loss, nums)) / total
        avg_metric = None
        if metric is not None:
            avg_metric = np.sum(np.multiply(metric, nums)) / total
    return avg_loss, total, avg_metric

def accuracy(output, label):
    _, preds = torch.max(output, dim=1)
    return torch.sum(label == preds).item() / len(preds)

# avg_loss, total, val_acc = evaluate(model, F.cross_entropy, val_loader, metric=accuracy)
# print("Loss: {:.4f}, total:{:.4f}, Accuracy: {:.4f}".format(avg_loss, total, val_acc))

def fit(epochs, model, loss_fn, opt, train_dl, valid_dl, metric=None):
    for epoch in range(epochs):
        for xb, yb in train_dl:
            loss, _, _ = loss_batch(model, loss_fn, xb, yb, opt, metric=metric)

        result = evaluate(model, loss_fn, valid_dl, metric=metric)
        val_loss, total, val_metric = result

        if metric is None:
            print("Epoch [{}/{}], total:{:.4f}, Loss: {:.4f}"
                  .format(epoch + 1, epochs, total, val_loss))
        else:
            print("Epoch [{}/{}], total:{:.4f}, Loss: {:.4f}, {}: {:.4f}"
                  .format(epoch + 1, epochs, total, val_loss, metric.__name__, val_metric))

model = MnistModel()

# if a saved model exists, load it instead of retraining
if path.exists('mnist-logistic.pth'):
    model.load_state_dict(torch.load('mnist-logistic.pth'))

else:
    fit(5,
        model,
        F.cross_entropy,
        torch.optim.SGD(model.parameters(), lr=0.001),
        train_loader,
        val_loader,
        metric=accuracy)
    # save the weights and biases for this model
    torch.save(model.state_dict(), 'mnist-logistic.pth')

# read the saved model into a new instance
# model2 = MnistModel()
# model2.load_state_dict(torch.load('mnist-logistic.pth'))
# model2.state_dict()

def prediction_img(img, model):
    xb = img.unsqueeze(0)
    yb = model(xb)
    _, preds = torch.max(yb, dim=1)
    return preds[0].item()

for i in range(10):
    img, label = test_dataset[randint(0, len(test_dataset) - 1)]
    img_np = np.array(img)
    plt.imshow(img_np.squeeze(), cmap='gray')
    plt.show()
    print(prediction_img(img, model))

Question

  1. when test_dataset was imported without the transform parameter, the validation section failed because img (still a PIL image) has no squeeze attribute
  2. zip(*results) unpacks the list of tuples and distributes them into separate sequences (see the sketch below)
  3. avg_loss = np.sum(np.multiply(loss, nums)) / total: multiply is used because the last batch may be smaller than the others, so each batch loss must be weighted by its batch size
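A tiny illustration of the zip(*results) idiom from point 2 and the weighted average from point 3, with made-up numbers:

import numpy as np

# pretend batch results: (loss, batch_size, metric) per batch
results = [(0.52, 100, 0.81), (0.47, 100, 0.84), (0.50, 40, 0.83)]

# zip(*results) transposes the list of tuples into three parallel tuples
losses, nums, metrics = zip(*results)
print(losses)  # (0.52, 0.47, 0.5)
print(nums)    # (100, 100, 40)

# weight each batch loss by its size, so the smaller last batch counts less
avg_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
print(avg_loss)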

Pytorch Tutorial 1

import torch
import numpy as np

t1 = torch.tensor(4.)
print(t1)
print(t1.dtype)

t2 = torch.tensor([1., 2, 3, 4])
print(t2)
print(t2.dtype)
# in this case all elements are converted to the same data type:
# [1., 2., 3., 4.]

t3 = torch.tensor([1., 2, 3, 4])
print(t3)
print(t3.dtype)

t4 = torch.tensor([[1, 2], [1., 4], [4, 3], [5, 6]])
print(t4)
print(t4.dtype)

print(t1.shape)
print(t2.shape)
print(t3.shape)
print(t4.shape)

# ---
x = torch.tensor(3., requires_grad=True)
w = torch.tensor(4., requires_grad=True)
b = torch.tensor(5., requires_grad=True)

y = w * x + b
print(y)
y.backward()

print(x.grad)
print(w.grad)
print(b.grad)

# convert numpy to torch
x = np.array([[1, 2], [2, 4]])

# use shared memory space, not copy
y = torch.from_numpy(x)

# copy data
y = torch.tensor(x)

print(y)
print(y.dtype)

# convert torch to numpy
z = y.numpy()
print(z)

Pytorch Tutorial 2

simple linear regression with the autograd method in PyTorch

  1. @ means matrix multiplication
  2. .t() means matrix transpose
  3. .numel() means the number of elements in the matrix
  4. with torch.no_grad(): means code inside this block will not track gradients, to save memory and computation time
import torch
import numpy as np

inputs = np.array([[0, 0, 3],
                   [0, 1, 9],
                   [1, 0, 8],
                   [1, 1, 28]], dtype='float32')

outputs = np.array([[0, 1],
                    [9, 4],
                    [7, 3],
                    [6, 7]], dtype='float32')

inputs = torch.from_numpy(inputs)
outputs = torch.from_numpy(outputs)

w = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, requires_grad=True)

# print(b)

def model(x):
    # b is a vector; when the matrix is added to b, b is broadcast
    # (effectively copied) across the rows to match the matrix shape
    return x @ w.t() + b

def mse(t1, t2):
    return torch.sum((t1 - t2) ** 2) / t1.numel()

learning_rate = 1e-5
for t in range(500):
    y_pred = model(inputs)
    loss = mse(y_pred, outputs)
    loss.backward()
    with torch.no_grad():
        w -= learning_rate * w.grad
        b -= learning_rate * b.grad
        w.grad.zero_()
        b.grad.zero_()
print(loss.item())

Shiro

Implementing permissions with Shiro

Other frameworks:

Spring Security (not yet studied; reportedly similar)

Configuration steps:

  1. Add the Maven dependencies
  2. Add a ShiroConfig
  3. Create a UserRealm

Notes:

  • When building the filter chain map, be sure to use a LinkedHashMap; otherwise wildcard patterns will not match in order and entries will be overridden
  • An unauthorised-redirect page can be set when the filter is created, but this does not suit front-end/back-end separated projects
  • Custom interceptors can be registered when the filter factory is created, for the project's own business logic
  • The login expiry time can be changed via session settings before calling subject.login()

Core concepts:

  • Subject retrieves the current user object
  • a token can be generated from username and password
  • Md5Hash can be used for hashing
  • after calling login, catch exceptions to distinguish the different login outcomes
<dependency>
    <groupId>org.apache.shiro</groupId>
    <artifactId>shiro-core</artifactId>
    <version>1.7.1</version>
</dependency>
<dependency>
    <groupId>org.apache.shiro</groupId>
    <artifactId>shiro-spring</artifactId>
    <version>1.4.0</version>
</dependency>
package com.jcDemo.shiro;

import com.jcDemo.interceptor.MyFormAuthenticationFilter;
import org.apache.shiro.mgt.DefaultSecurityManager;
import org.apache.shiro.mgt.SecurityManager;
import org.apache.shiro.spring.security.interceptor.AuthorizationAttributeSourceAdvisor;
import org.apache.shiro.spring.web.ShiroFilterFactoryBean;
import org.apache.shiro.web.mgt.DefaultWebSecurityManager;
import org.springframework.aop.framework.autoproxy.DefaultAdvisorAutoProxyCreator;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import javax.servlet.Filter;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 *@author:gao_quansui
 *@user:ASUS
 *@date:2022/9/21- 10:57
 *@projectName:jc_demo
 */

@Configuration
public class ShiroConfig {

    @Bean
    public UserRealm userRealm() {
        return new UserRealm();
    }

    @Bean(name = "filterShiroFilterRegistrationBean")
    public ShiroFilterFactoryBean getShiroFilterFactoryBean(@Qualifier("SecurityManager") DefaultWebSecurityManager defaultWebSecurityManager) {
        // 1. create the filter factory
        ShiroFilterFactoryBean bean = new ShiroFilterFactoryBean();

        Map<String, Filter> filters = new LinkedHashMap<>();
        filters.put("MyFormAuthenticationFilter", new MyFormAuthenticationFilter());
        bean.setFilters(filters);
        // 2. set the security manager
        bean.setSecurityManager(defaultWebSecurityManager);

        // 3. configure the unauthorised-redirect page
        // bean.setLoginUrl("/test");
        // bean.setLoginUrl(null);

        // 4. set up the filter chain (a LinkedHashMap, so patterns match in order)
        Map<String, String> filterMap = new LinkedHashMap<>();
        filterMap.put("/api/user/login", "anon");

        filterMap.put("/api/**", "MyFormAuthenticationFilter");

        bean.setFilterChainDefinitionMap(filterMap);

        return bean;
    }

    @Bean(name = "SecurityManager")
    public DefaultWebSecurityManager getDefaultWebSecurityManager(@Qualifier("userRealm") UserRealm userRealm) {
        DefaultWebSecurityManager securityManager = new DefaultWebSecurityManager();
        securityManager.setRealm(userRealm);
        return securityManager;
    }

    // enable Shiro annotations
    @Bean
    public AuthorizationAttributeSourceAdvisor authorizationAttributeSourceAdvisor(SecurityManager securityManager) {
        AuthorizationAttributeSourceAdvisor advisor = new AuthorizationAttributeSourceAdvisor();
        advisor.setSecurityManager(securityManager);
        return advisor;
    }

    // enable AOP annotation support
    @Bean
    public DefaultAdvisorAutoProxyCreator defaultAdvisorAutoProxyCreator() {
        DefaultAdvisorAutoProxyCreator defaultAAP = new DefaultAdvisorAutoProxyCreator();
        defaultAAP.setProxyTargetClass(true);
        return defaultAAP;
    }
}

package com.jcDemo.shiro;

import com.jcDemo.entity.entities.User;
import com.jcDemo.entity.vo.UidUsernameName;
import com.jcDemo.service.user.RoleService;
import com.jcDemo.service.user.UserService;
import lombok.extern.slf4j.Slf4j;
import org.apache.shiro.SecurityUtils;
import org.apache.shiro.authc.*;
import org.apache.shiro.authz.AuthorizationInfo;
import org.apache.shiro.authz.SimpleAuthorizationInfo;
import org.apache.shiro.realm.AuthorizingRealm;
import org.apache.shiro.subject.PrincipalCollection;
import org.apache.shiro.subject.Subject;
import org.springframework.beans.factory.annotation.Autowired;

import java.util.List;

/**
 *@author:gao_quansui
 *@user:ASUS
 *@date:2022/9/21- 10:07
 *@projectName:jc_demo
 */
@Slf4j
public class UserRealm extends AuthorizingRealm {

    @Autowired
    UserService userService;

    @Autowired
    RoleService roleService;

    // authorisation
    @Override
    protected AuthorizationInfo doGetAuthorizationInfo(PrincipalCollection principalCollection) {
        log.info("running ==================== authorisation");
        SimpleAuthorizationInfo info = new SimpleAuthorizationInfo();

        // Object data = roleService.getUserRoleById(1).getData();
        // info.addStringPermission("");

        Subject subject = SecurityUtils.getSubject();

        User user = (User) subject.getPrincipal(); // the user object passed up from the method below

        // get the current user's list of roles
        List<UidUsernameName> uidUsernameName = (List<UidUsernameName>) roleService.getUserRoleById(user.getId());

        // iterate and add the permissions
        uidUsernameName.forEach(e -> {
            info.addStringPermission(e.getRoleName());
            log.info("username:{}-------roleName:{}", e.getUsername(), e.getRoleName());
        });

        return info;
    }

    // authentication
    @Override
    protected AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken token) throws AuthenticationException {

        UsernamePasswordToken userToken = (UsernamePasswordToken) token;

        User user = userService.getUserByName(userToken.getUsername());

        if (user == null) {
            return null;
        }

        log.info("authentication");

        Subject subject = SecurityUtils.getSubject();
        // subject.isPermitted("123");
        // subject.hasRole("authc");

        // if (!userToken.getUsername().equals(user.getUsername())) {
        //     return null; // caught in the controller
        // }
        //
        return new SimpleAuthenticationInfo(user, user.getPassword(), ""); // the password is verified here
        // return null;
    }
}

package com.jcDemo.interceptor;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.jcDemo.entity.res.Result;
import com.jcDemo.entity.res.ResultCode;
import lombok.extern.slf4j.Slf4j;
import org.apache.shiro.web.filter.authc.FormAuthenticationFilter;
import org.apache.shiro.web.servlet.ShiroHttpServletRequest;
import org.springframework.util.StringUtils;

import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.PrintWriter;

/**
 *@author:gao_quansui
 *@user:ASUS
 *@date:2022/9/28- 10:07
 *@projectName:jc_demo
 */
@Slf4j
public class MyFormAuthenticationFilter extends FormAuthenticationFilter {
    @Override
    protected boolean onAccessDenied(ServletRequest request, ServletResponse response) throws Exception {
        log.info("intercepted by shiro =================================");

        PrintWriter out = null;
        try {
            HttpServletResponse res = (HttpServletResponse) response;
            response.setCharacterEncoding("UTF-8");
            response.setContentType("application/json; charset=utf-8");
            out = response.getWriter();
            if (res.getStatus() == HttpServletResponse.SC_UNAUTHORIZED) {
                out.println(JSON.toJSONString(new Result(ResultCode.UNAUTHORISE)));
            } else {
                if (StringUtils.isEmpty(((ShiroHttpServletRequest) request).getHeader("Authorization"))) {
                    log.info("not logged in");
                    out.write(JSONObject.toJSONString(new Result(ResultCode.UNAUTHENTICATED, "")));
                } else {
                    log.info("session expired");
                    out.write(JSONObject.toJSONString(new Result(ResultCode.EXPIREDSESSION, "")));
                }
            }
        } catch (IOException e) {
            log.info("session exception");
        } finally {
            if (out != null) {
                out.close();
            }
        }
        return Boolean.FALSE;
    }

    @Override
    protected boolean isAccessAllowed(ServletRequest request, ServletResponse response, Object mappedValue) {
        return super.isAccessAllowed(request, response, mappedValue);
    }
}

Integrating Swagger

Swagger overview

Reference: a Swagger quick-start tutorial (CSDN blog)

Steps

  1. Add the Maven dependencies
  2. Configure SwaggerConfig (author, project name, email, etc.)
  3. After startup, visit the Swagger UI path

Usage

  • In entity classes
    • @ApiModel describes the class
    • @ApiModelProperty describes the fields of the class
  • In controllers
    • @Api describes the controller
    • @ApiOperation describes an endpoint
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web</artifactId>
</dependency>

<dependency>
    <groupId>io.springfox</groupId>
    <artifactId>springfox-boot-starter</artifactId>
    <version>3.0.0</version>
</dependency>
package com.jcDemo.config;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import springfox.documentation.builders.ApiInfoBuilder;
import springfox.documentation.builders.PathSelectors;
import springfox.documentation.builders.RequestHandlerSelectors;
import springfox.documentation.service.ApiInfo;
import springfox.documentation.service.Contact;
import springfox.documentation.spi.DocumentationType;
import springfox.documentation.spring.web.plugins.Docket;
import springfox.documentation.swagger2.annotations.EnableSwagger2;

/**
 *@author:gao_quansui
 *@user:ASUS
 *@date:2022/9/22- 10:08
 *@projectName:jc_demo
 */
@Configuration
public class SwaggerConfig {
    @Bean
    public Docket docket() {
        return new Docket(DocumentationType.OAS_30).apiInfo(
                new ApiInfoBuilder()
                        .contact(new Contact("gqs", "", ""))
                        .title("JC_Demo")
                        .build()
        );
    }
}

Integrating the Redis cache

Implementing and analysing distributed locks

Reference: implementing distributed locks with Redis (CSDN blog)

Integration steps:

  1. Download Redis, unzip it, and edit the config file (redis.windows.conf)
  2. Add the Redis starter dependency
  3. Create the Redis configuration class (mainly for serialising and deserialising keys and values)
  4. Start redis-server.exe (the Redis service); it can also be started with a specified config file, e.g. for a cluster
  5. Start redis-cli.exe (the client)

Notes:

  • A Redis hash cannot be given a per-field expiry time!!!
  • Store keys in the form device:No_001:Name
  • FastJson is needed when storing object values
  • All Redis operations live in redisTemplate; there are many, and they need to be learned

Problems a cache may introduce (those I can think of so far):

  1. After an update, the cache must be handled too, otherwise it holds stale data
    1. On update, check whether the value is cached?
      1. Cached: update the cache, and use the cache to update the persistence layer (this can be done slowly, as long as it finishes before the cache expires)
      2. Not cached: update the persistence layer, then cache the value (anything just modified will surely be needed soon)
  2. Deletes must remove from both sides
  3. On insert, add to the cache (new records will also be needed soon)
  4. On lookup: if the cache has the value, should its expiry time be refreshed?
  5. Does pagination need caching, and how?

Possible improvements:

  1. Redis cluster
    1. use a Redis String as the lock (a minimal sketch follows this list)
      1. setnx device:NO_001_Lock 1 ex time

      2. 0 means the lock is held, so retry the request; 1 means proceed with the business operation

      3. set an expiry time so that a crash after acquiring the lock cannot cause deadlock; the duration must be chosen well

        1. too short: the lock is released before the operation finishes
        2. too long: the operation has finished but the lock is still held
      4. solution: a watchdog thread

        1. set a fixed duration
        2. e.g. set 10 s; at 8 s check whether the task is still running, and extend if so, otherwise don't
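A minimal sketch of the lock idea above, in Python with the redis-py client (an assumption for illustration; the project itself uses Spring's RedisTemplate). The key name is taken from the note; everything else is mine:

import time
import redis

r = redis.Redis(host="localhost", port=6379)
LOCK_KEY = "device:NO_001_Lock"

def with_lock(work, ttl_seconds=10, retries=50):
    # SET key 1 NX EX ttl is the atomic form of "setnx + expire": the key is
    # only set if absent, and the expiry is attached in the same step, so a
    # crash between the two commands cannot leave a lock without a TTL.
    for _ in range(retries):
        if r.set(LOCK_KEY, 1, nx=True, ex=ttl_seconds):
            try:
                return work()  # lock held: do the business operation
            finally:
                # release; a production lock would store a unique token and
                # check it before deleting, and a watchdog would extend the TTL
                r.delete(LOCK_KEY)
        time.sleep(0.1)  # lock held elsewhere: retry the request
    raise TimeoutError("could not acquire lock")

# with_lock(lambda: print("exclusive section"))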
<!-- Redis dependency -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
package com.jcDemo.config;

import com.alibaba.fastjson2.support.spring.data.redis.FastJsonRedisSerializer;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.*;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.data.redis.serializer.GenericToStringSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;

/**
 *@author:gao_quansui
 *@user:ASUS
 *@date:2022/9/28- 13:41
 *@projectName:jc_demo
 */
@Configuration
public class RedisConfig {

    @Bean
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory factory) {
        FastJsonRedisSerializer<Object> objectFastJsonRedisSerializer = new FastJsonRedisSerializer<>(Object.class);
        StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
        // the customised RedisTemplate
        RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
        // set the key serialiser
        redisTemplate.setKeySerializer(new StringRedisSerializer());
        // the core setting; provided automatically since version 1.2.36
        redisTemplate.setValueSerializer(objectFastJsonRedisSerializer);
        // serialisers for hash operations
        redisTemplate.setHashKeySerializer(stringRedisSerializer);
        redisTemplate.setHashValueSerializer(objectFastJsonRedisSerializer);
        // register the connection factory
        redisTemplate.setConnectionFactory(factory);
        return redisTemplate;
    }

    @Bean
    public ValueOperations<String, Object> valueOperations(RedisTemplate<String, Object> redisTemplate) {
        return redisTemplate.opsForValue();
    }

    @Bean
    public HashOperations<String, String, Object> hashOperations(RedisTemplate<String, Object> redisTemplate) {
        return redisTemplate.opsForHash();
    }

    @Bean
    public ListOperations<String, Object> listOperations(RedisTemplate<String, Object> redisTemplate) {
        return redisTemplate.opsForList();
    }

    @Bean
    public SetOperations<String, Object> setOperations(RedisTemplate<String, Object> redisTemplate) {
        return redisTemplate.opsForSet();
    }

    @Bean
    public ZSetOperations<String, Object> zSetOperations(RedisTemplate<String, Object> redisTemplate) {
        return redisTemplate.opsForZSet();
    }
}

package com.jcDemo.controller;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.github.pagehelper.Page;
import com.github.pagehelper.PageHelper;
import com.jcDemo.commom.CommonException;
import com.jcDemo.entity.entities.Device;
import com.jcDemo.entity.res.PageResult;
import com.jcDemo.entity.res.Result;
import com.jcDemo.entity.res.ResultCode;
import com.jcDemo.service.device.DeviceService;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiOperation;
import lombok.extern.slf4j.Slf4j;
import lombok.var;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.connection.RedisConnection;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.*;

import javax.annotation.Resource;
import java.text.ParseException;
import java.util.List;
import java.util.concurrent.TimeUnit;

/**
 *@author:gao_quansui
 *@user:ASUS
 *@date:2022/9/23- 10:51
 *@projectName:jc_demo
 */
@Slf4j
@Api(tags = "Device management")
@RestController
@RequestMapping("/api/device")
public class DeviceController {

    @Autowired
    DeviceService deviceService;

    @Resource
    RedisTemplate redisTemplate;

    @ApiOperation("Find all devices")
    @GetMapping("/getDevices")
    public Result getDevices(@RequestParam(defaultValue = "1") int pageNum,
                             @RequestParam(defaultValue = "10") int pageSize) {
        PageResult pr = null;
        try {
            Page page = PageHelper.startPage(pageNum, pageSize);
            List<Device> devices = deviceService.getDevices();

            pr = new PageResult(page.getTotal(), devices);
            log.info("pageNum:{},pageSize:{}", pageNum, pageSize);
        } catch (CommonException e) {
            e.printStackTrace();
            return new Result(ResultCode.EMPTY);
        }
        return new Result(ResultCode.SUCCESS, pr);
    }

    // look up by device number
    @ApiOperation("Find device by No")
    @GetMapping("/getDeviceByNo/{no}")
    public Result getDeviceByNo(@PathVariable("no") String no) throws CommonException {
        Device device = null;
        try {
            if (redisTemplate.hasKey("devices:" + no)) {
                log.info("resetting expiry -> devices:{}", no);
                redisTemplate.expire("devices:" + no, 300, TimeUnit.SECONDS);
                log.info("fetched from redis: devices:{}", no);
                device = JSON.parseObject(String.valueOf(redisTemplate.opsForValue().get("devices:" + no)), Device.class);
            } else {
                log.info("fetched from mysql: {}", no);
                // if mysql has nothing, an exception is thrown; it is caught below and an empty result is returned
                device = deviceService.getDeviceByNo(no);
                // store the fetched value in the cache
                redisTemplate.opsForValue().set("devices:" + no, JSON.toJSONString(device), 300, TimeUnit.SECONDS);
            }
            return new Result(ResultCode.SUCCESS, device);
        } catch (CommonException e) {
            e.printStackTrace();
            return new Result(ResultCode.EMPTY);
        }
    }

    // search
    @ApiOperation("Header search")
    @GetMapping("/searchDevice")
    public Result searchDevice(@RequestBody Device device) {
        List<Device> devices;
        try {
            devices = deviceService.searchDevice(device);
            return new Result(ResultCode.SUCCESS, devices);
        } catch (CommonException e) {
            e.printStackTrace();
            return new Result(ResultCode.EMPTY);
        }
    }

    @ApiOperation("Update device")
    @PostMapping("/updateDevice")
    public Result updateDevice(@RequestBody Device device) throws ParseException {
        if (deviceService.updateDevice(device) == 1) {
            // after a successful update, check for a cache entry and replace it if present
            if (redisTemplate.hasKey("devices:" + device.getDeviceNo())) {
                Boolean delete = redisTemplate.delete("devices:" + device.getDeviceNo());
                if (delete) {
                    redisTemplate.opsForValue().set("devices:" + device.getDeviceNo(), JSON.toJSONString(device), 300, TimeUnit.SECONDS);
                } else {
                    log.warn("cache update failed, please check!");
                }
            } else {
                redisTemplate.opsForValue().set("devices:" + device.getDeviceNo(), JSON.toJSONString(device), 300, TimeUnit.SECONDS);
            }
            return new Result(ResultCode.SUCCESS);
        } else {
            return new Result(ResultCode.ERROR);
        }
    }

    @ApiOperation("Delete device")
    @DeleteMapping("/delete/{no}")
    public Result deleteDeviceById(@PathVariable("no") String no) {
        if (deviceService.deleteDeviceByNo(no) == 1) {
            // after a successful delete, clean up redis
            if (redisTemplate.delete("devices:" + no)) {
                log.warn("cache delete succeeded");
            } else {
                // the cache delete failed, or the data came from a select and never entered redis; a direct delete also logs this
                log.warn("cache delete failed, please check!");
            }
            return new Result(ResultCode.SUCCESS);
        } else {
            return new Result(ResultCode.ERROR);
        }
    }

    @ApiOperation("Insert device")
    @PostMapping("/insert")
    public Result insertDevice(@RequestBody Device device) throws ParseException {
        if (deviceService.insertDevice(device) == 1) {
            redisTemplate.opsForValue().set("devices:" + device.getDeviceNo(), JSON.toJSONString(device), 300, TimeUnit.SECONDS);
            return new Result(ResultCode.SUCCESS);
        } else {
            return new Result(ResultCode.ERROR);
        }
    }
}