1- import pyqtgraph as pg
21import numpy as np
32import os
43
54from matplotlib .backends .qt_compat import QtCore , QtWidgets
65from matplotlib .backends .backend_qt5agg import FigureCanvas , NavigationToolbar2QT as NavigationToolbar
7- from matplotlib .figure import Figure
8- from matplotlib .lines import Line2D
6+ from matplotlib import figure , lines , patches , path
97
108from .qt import QtWidgets , QtCore , \
119 PYSIDE2_LOADED , PYQT5_LOADED
3533 TF_LOADED = False
3634
3735
36+ class BarGraphWidget (QtWidgets .QWidget ):
37+ def __init__ (self ):
38+ super ().__init__ ()
39+ layout = QtWidgets .QVBoxLayout (self )
40+ self .canvas = FigureCanvas (figure .Figure (figsize = (5 , 3 )))
41+ layout .addWidget (self .canvas )
42+
43+ self .nbrCategories = 0
44+ self .offset_nullValue = .01
45+ self .ax = self .canvas .figure .subplots ()
46+ self .ax .set_xlim (0.0 , 1.0 )
47+ self .ax .set_ylim (0.0 , 1.0 )
48+
49+ font = {'family' : 'serif' ,
50+ 'color' : 'darkred' ,
51+ 'weight' : 'normal' ,
52+ 'fontsize' : 'smaller' ,
53+ }
54+
55+ self .initPlot (self .nbrCategories )
56+ self .updateValues (np .random .rand (self .nbrCategories ))
57+
58+ def initPlot (self , nbrBar :int ):
59+ self .nbrCategories = nbrBar
60+ if self .nbrCategories == 0 :
61+ bottom = 0
62+ top = 0
63+ left = 0
64+ right = self .offset_nullValue
65+ nrects = 0
66+
67+ else :
68+ bins = np .array ([float (i )/ self .nbrCategories for i in range (self .nbrCategories + 1 )])
69+
70+ bottom = bins [:- 1 ] + (.1 / self .nbrCategories )
71+ top = bins [1 :] - (.1 / self .nbrCategories )
72+ left = np .zeros (len (top ))
73+ right = left + self .offset_nullValue
74+ nrects = len (top )
75+
76+ nverts = nrects * (1 + 3 + 1 )
77+ self .verts = np .zeros ((nverts , 2 ))
78+ codes = np .full (nverts , path .Path .LINETO )
79+ codes [0 ::5 ] = path .Path .MOVETO
80+ codes [4 ::5 ] = path .Path .CLOSEPOLY
81+ self .verts [0 ::5 , 0 ] = left
82+ self .verts [0 ::5 , 1 ] = bottom
83+ self .verts [1 ::5 , 0 ] = left
84+ self .verts [1 ::5 , 1 ] = top
85+ self .verts [2 ::5 , 0 ] = right
86+ self .verts [2 ::5 , 1 ] = top
87+ self .verts [3 ::5 , 0 ] = right
88+ self .verts [3 ::5 , 1 ] = bottom
89+
90+ patch = None
91+
92+ barpath = path .Path (self .verts , codes )
93+ patch = patches .PathPatch (
94+ barpath , facecolor = 'green' , alpha = 0.5 ) #edgecolor='yellow',
95+ self .ax .add_patch (patch )
96+
97+ self .canvas .draw ()
98+
99+ def updateValues (self , values :np .ndarray ):
100+ self .verts [2 ::5 , 0 ] = values + self .offset_nullValue
101+ self .verts [3 ::5 , 0 ] = values + self .offset_nullValue
102+ self .canvas .draw ()
103+
104+ def changeCategories (self , categories ):
105+ self .initPlot (len (categories ))
106+ for cat in categories :
107+ self .ax .text (0.01 , 0.5 , r'$\cos(2 \pi t) \exp(-t)$' , fontdict = font )
108+
109+
38110class HandPlotWidget (QtWidgets .QWidget ):
39111 def __init__ (self ):
40112 super ().__init__ ()
41113 layout = QtWidgets .QVBoxLayout (self )
42- self .static_canvas = FigureCanvas (Figure (figsize = (5 , 3 )))
43- layout .addWidget (self .static_canvas )
114+ self .canvas = FigureCanvas (figure . Figure (figsize = (5 , 3 )))
115+ layout .addWidget (self .canvas )
44116
45- self .ax = self .static_canvas .figure .subplots ()
117+ self .ax = self .canvas .figure .subplots ()
46118 self .ax .set_xlim ([- 1. ,1. ])
47119 self .ax .set_ylim ([- 1. ,1. ])
48120 self .ax .set_aspect ('equal' )
49121
50122 self .fingerLines = [
51- Line2D ([], [], color = 'r' ),
52- Line2D ([], [], color = 'y' ),
53- Line2D ([], [], color = 'g' ),
54- Line2D ([], [], color = 'b' ),
55- Line2D ([], [], color = 'm' )]
123+ lines . Line2D ([], [], color = 'r' ),
124+ lines . Line2D ([], [], color = 'y' ),
125+ lines . Line2D ([], [], color = 'g' ),
126+ lines . Line2D ([], [], color = 'b' ),
127+ lines . Line2D ([], [], color = 'm' )]
56128
57129 for line in self .fingerLines :
58130 self .ax .add_line (line )
@@ -67,12 +139,12 @@ def plotHand(self, handKeypoints):
67139 np .insert (handKeypoints [:, 17 :21 ].T , 0 , handKeypoints [:,0 ], axis = 0 ).T ]
68140 for i ,line in enumerate (self .fingerLines ):
69141 line .set_data (data [i ][0 ], data [i ][1 ])
70- self .static_canvas .draw ()
142+ self .canvas .draw ()
71143
72144 def clear (self ):
73145 for line in self .fingerLines :
74146 line .set_data ([], [])
75- self .static_canvas .draw ()
147+ self .canvas .draw ()
76148
77149
78150class HandAnalysisWidget (QtWidgets .QGroupBox ):
@@ -88,13 +160,7 @@ def __init__(self, handID:int, showInput:bool=True):
88160 self .layout = QtWidgets .QGridLayout (self )
89161 self .setLayout (self .layout )
90162
91- self .classGraphWidget = pg .PlotWidget ()
92- self .classGraphWidget .setBackground ('w' )
93- self .classGraphWidget .setYRange (0.0 , 1.0 )
94- self .classGraphWidget .setTitle ('Predicted class: ' )
95-
96- self .outputGraph = pg .BarGraphItem (x = range (len (self .classOutputs )), height = [0 ]* len (self .classOutputs ), width = 0.6 , brush = 'k' )
97- self .classGraphWidget .addItem (self .outputGraph )
163+ self .classGraphWidget = BarGraphWidget ()
98164
99165 if self .showInput :
100166 self .handGraphWidget = HandPlotWidget ()
@@ -124,6 +190,8 @@ def drawHand(self, handKeypoints:np.ndarray, accuracy:float):
124190 self .updatePredictedClass (handKeypoints )
125191 if isHandData (handKeypoints ):
126192 self .handGraphWidget .plotHand (handKeypoints )
193+ else :
194+ self .handGraphWidget .clear ()
127195
128196
129197 def updatePredictedClass (self , keypoints :np .ndarray ):
@@ -133,7 +201,7 @@ def updatePredictedClass(self, keypoints:np.ndarray):
133201 keypoints (np.ndarray((3,21),float)): Coordinates x, y and the accuracy score for each 21 key points.
134202 '''
135203
136- prediction = [0 ] * len ( self .classOutputs )
204+ prediction = [0 for i in self .classOutputs ]
137205 title = 'Predicted class: None'
138206 if type (keypoints ) != type (None ):
139207 inputData = []
@@ -147,24 +215,20 @@ def updatePredictedClass(self, keypoints:np.ndarray):
147215 self .currentPrediction = self .classOutputs [np .argmax (prediction )]
148216 title = 'Predicted class: ' + self .currentPrediction
149217
150- self .outputGraph .setOpts (height = prediction )
151- self .classGraphWidget .setTitle (title )
218+ self .classGraphWidget .updateValues (np .array (prediction ))
152219
153220 def newModelLoaded (self , urlModel :str , classOutputs :list , handID :int ):
154221 if TF_LOADED :
155222 if urlModel == 'None' :
156223 self .modelClassifier = None
157224 self .classOutputs = []
158- self .outputGraph .setOpts (x = range (1 ,len (self .classOutputs )+ 1 ), height = [0 ]* len (self .classOutputs ))
159-
225+ self .classGraphWidget .changeCategories (self .classOutputs )
160226 else :
161227 if handID == self .handID :
162228 self .modelClassifier = tf .keras .models .load_model (urlModel )
163229 self .classOutputs = classOutputs
164- self .outputGraph .setOpts (x = range (1 ,len (self .classOutputs )+ 1 ), height = [0 ]* len (self .classOutputs ))
165- self .modelClassifier = tf .keras .models .load_model (urlModel )
166- self .classOutputs = classOutputs
167- self .outputGraph .setOpts (x = range (1 ,len (self .classOutputs )+ 1 ), height = [0 ]* len (self .classOutputs ))
230+ self .classGraphWidget .changeCategories (self .classOutputs )
231+ #self.classGraphWidget.setOpts(x=range(1,len(self.classOutputs)+1), height=[0]*len(self.classOutputs))
168232
169233 def getCurrentPrediction (self )-> str :
170234 return self .currentPrediction
0 commit comments