Skip to content

Commit 7e5a315

Browse files
committed
📦MatplotLib instead of pyqtgraph for classification visualization
1 parent cbde4b2 commit 7e5a315

2 files changed

Lines changed: 97 additions & 30 deletions

File tree

openhand_classifier/src/HandAnalysis.py

Lines changed: 93 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
import pyqtgraph as pg
21
import numpy as np
32
import os
43

54
from matplotlib.backends.qt_compat import QtCore, QtWidgets
65
from matplotlib.backends.backend_qt5agg import FigureCanvas, NavigationToolbar2QT as NavigationToolbar
7-
from matplotlib.figure import Figure
8-
from matplotlib.lines import Line2D
6+
from matplotlib import figure, lines, patches, path
97

108
from .qt import QtWidgets, QtCore, \
119
PYSIDE2_LOADED, PYQT5_LOADED
@@ -35,24 +33,98 @@
3533
TF_LOADED = False
3634

3735

36+
class BarGraphWidget(QtWidgets.QWidget):
37+
def __init__(self):
38+
super().__init__()
39+
layout = QtWidgets.QVBoxLayout(self)
40+
self.canvas = FigureCanvas(figure.Figure(figsize=(5, 3)))
41+
layout.addWidget(self.canvas)
42+
43+
self.nbrCategories = 0
44+
self.offset_nullValue = .01
45+
self.ax = self.canvas.figure.subplots()
46+
self.ax.set_xlim(0.0, 1.0)
47+
self.ax.set_ylim(0.0, 1.0)
48+
49+
font = {'family': 'serif',
50+
'color': 'darkred',
51+
'weight': 'normal',
52+
'fontsize': 'smaller',
53+
}
54+
55+
self.initPlot(self.nbrCategories)
56+
self.updateValues(np.random.rand(self.nbrCategories))
57+
58+
def initPlot(self, nbrBar:int):
59+
self.nbrCategories = nbrBar
60+
if self.nbrCategories == 0:
61+
bottom = 0
62+
top = 0
63+
left = 0
64+
right = self.offset_nullValue
65+
nrects = 0
66+
67+
else:
68+
bins = np.array([float(i)/self.nbrCategories for i in range(self.nbrCategories+1)])
69+
70+
bottom = bins[:-1] + (.1/self.nbrCategories)
71+
top = bins[1:] - (.1/self.nbrCategories)
72+
left = np.zeros(len(top))
73+
right = left + self.offset_nullValue
74+
nrects = len(top)
75+
76+
nverts = nrects * (1 + 3 + 1)
77+
self.verts = np.zeros((nverts, 2))
78+
codes = np.full(nverts, path.Path.LINETO)
79+
codes[0::5] = path.Path.MOVETO
80+
codes[4::5] = path.Path.CLOSEPOLY
81+
self.verts[0::5, 0] = left
82+
self.verts[0::5, 1] = bottom
83+
self.verts[1::5, 0] = left
84+
self.verts[1::5, 1] = top
85+
self.verts[2::5, 0] = right
86+
self.verts[2::5, 1] = top
87+
self.verts[3::5, 0] = right
88+
self.verts[3::5, 1] = bottom
89+
90+
patch = None
91+
92+
barpath = path.Path(self.verts, codes)
93+
patch = patches.PathPatch(
94+
barpath, facecolor='green', alpha=0.5) #edgecolor='yellow',
95+
self.ax.add_patch(patch)
96+
97+
self.canvas.draw()
98+
99+
def updateValues(self, values:np.ndarray):
100+
self.verts[2::5, 0] = values + self.offset_nullValue
101+
self.verts[3::5, 0] = values + self.offset_nullValue
102+
self.canvas.draw()
103+
104+
def changeCategories(self, categories):
105+
self.initPlot(len(categories))
106+
for cat in categories:
107+
self.ax.text(0.01, 0.5, r'$\cos(2 \pi t) \exp(-t)$', fontdict=font)
108+
109+
38110
class HandPlotWidget(QtWidgets.QWidget):
39111
def __init__(self):
40112
super().__init__()
41113
layout = QtWidgets.QVBoxLayout(self)
42-
self.static_canvas = FigureCanvas(Figure(figsize=(5, 3)))
43-
layout.addWidget(self.static_canvas)
114+
self.canvas = FigureCanvas(figure.Figure(figsize=(5, 3)))
115+
layout.addWidget(self.canvas)
44116

45-
self.ax = self.static_canvas.figure.subplots()
117+
self.ax = self.canvas.figure.subplots()
46118
self.ax.set_xlim([-1.,1.])
47119
self.ax.set_ylim([-1.,1.])
48120
self.ax.set_aspect('equal')
49121

50122
self.fingerLines = [
51-
Line2D([], [], color='r'),
52-
Line2D([], [], color='y'),
53-
Line2D([], [], color='g'),
54-
Line2D([], [], color='b'),
55-
Line2D([], [], color='m')]
123+
lines.Line2D([], [], color='r'),
124+
lines.Line2D([], [], color='y'),
125+
lines.Line2D([], [], color='g'),
126+
lines.Line2D([], [], color='b'),
127+
lines.Line2D([], [], color='m')]
56128

57129
for line in self.fingerLines:
58130
self.ax.add_line(line)
@@ -67,12 +139,12 @@ def plotHand(self, handKeypoints):
67139
np.insert(handKeypoints[:, 17:21].T, 0, handKeypoints[:,0], axis=0).T]
68140
for i,line in enumerate(self.fingerLines):
69141
line.set_data(data[i][0], data[i][1])
70-
self.static_canvas.draw()
142+
self.canvas.draw()
71143

72144
def clear(self):
73145
for line in self.fingerLines:
74146
line.set_data([], [])
75-
self.static_canvas.draw()
147+
self.canvas.draw()
76148

77149

78150
class HandAnalysisWidget(QtWidgets.QGroupBox):
@@ -88,13 +160,7 @@ def __init__(self, handID:int, showInput:bool=True):
88160
self.layout=QtWidgets.QGridLayout(self)
89161
self.setLayout(self.layout)
90162

91-
self.classGraphWidget = pg.PlotWidget()
92-
self.classGraphWidget.setBackground('w')
93-
self.classGraphWidget.setYRange(0.0, 1.0)
94-
self.classGraphWidget.setTitle('Predicted class: ')
95-
96-
self.outputGraph = pg.BarGraphItem(x=range(len(self.classOutputs)), height=[0]*len(self.classOutputs), width=0.6, brush='k')
97-
self.classGraphWidget.addItem(self.outputGraph)
163+
self.classGraphWidget = BarGraphWidget()
98164

99165
if self.showInput:
100166
self.handGraphWidget = HandPlotWidget()
@@ -124,6 +190,8 @@ def drawHand(self, handKeypoints:np.ndarray, accuracy:float):
124190
self.updatePredictedClass(handKeypoints)
125191
if isHandData(handKeypoints):
126192
self.handGraphWidget.plotHand(handKeypoints)
193+
else:
194+
self.handGraphWidget.clear()
127195

128196

129197
def updatePredictedClass(self, keypoints:np.ndarray):
@@ -133,7 +201,7 @@ def updatePredictedClass(self, keypoints:np.ndarray):
133201
keypoints (np.ndarray((3,21),float)): Coordinates x, y and the accuracy score for each 21 key points.
134202
'''
135203

136-
prediction = [0]*len(self.classOutputs)
204+
prediction = [0 for i in self.classOutputs]
137205
title = 'Predicted class: None'
138206
if type(keypoints) != type(None):
139207
inputData = []
@@ -147,24 +215,20 @@ def updatePredictedClass(self, keypoints:np.ndarray):
147215
self.currentPrediction = self.classOutputs[np.argmax(prediction)]
148216
title = 'Predicted class: ' + self.currentPrediction
149217

150-
self.outputGraph.setOpts(height=prediction)
151-
self.classGraphWidget.setTitle(title)
218+
self.classGraphWidget.updateValues(np.array(prediction))
152219

153220
def newModelLoaded(self, urlModel:str, classOutputs:list, handID:int):
154221
if TF_LOADED:
155222
if urlModel == 'None':
156223
self.modelClassifier = None
157224
self.classOutputs = []
158-
self.outputGraph.setOpts(x=range(1,len(self.classOutputs)+1), height=[0]*len(self.classOutputs))
159-
225+
self.classGraphWidget.changeCategories(self.classOutputs)
160226
else:
161227
if handID == self.handID:
162228
self.modelClassifier = tf.keras.models.load_model(urlModel)
163229
self.classOutputs = classOutputs
164-
self.outputGraph.setOpts(x=range(1,len(self.classOutputs)+1), height=[0]*len(self.classOutputs))
165-
self.modelClassifier = tf.keras.models.load_model(urlModel)
166-
self.classOutputs = classOutputs
167-
self.outputGraph.setOpts(x=range(1,len(self.classOutputs)+1), height=[0]*len(self.classOutputs))
230+
self.classGraphWidget.changeCategories(self.classOutputs)
231+
#self.classGraphWidget.setOpts(x=range(1,len(self.classOutputs)+1), height=[0]*len(self.classOutputs))
168232

169233
def getCurrentPrediction(self)->str:
170234
return self.currentPrediction

openhand_classifier/src/VideoInput.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,10 @@ def isRaisingHand(self):
320320
rightShoulder_x, rightShoulder_y, rightShoulder_a = poseKeypoints[2]
321321
leftShoulder_x, leftShoulder_y, leftShoulder_a = poseKeypoints[5]
322322

323-
shoulderSlope = (rightShoulder_y - leftShoulder_y) / (rightShoulder_x - leftShoulder_x)
323+
try:
324+
shoulderSlope = (rightShoulder_y - leftShoulder_y) / (rightShoulder_x - leftShoulder_x)
325+
except:
326+
shoulderSlope = 0.0
324327
shoulderOri = rightShoulder_y - shoulderSlope * rightShoulder_x
325328

326329
if leftHand_a > 0.1:

0 commit comments

Comments
 (0)