2015-11-20 27 views
6

如何更新可滚动的 matplotlib multiplot 中的艺术家(artist)

我想基于这个问题的答案创建一个可滚动的 multiplot:Creating a scrollable multiplot with python's pylab。使用 ax.plot() 创建的线条可以正确更新,但我不知道如何更新使用 vlines() 和 fill_between() 创建的艺术家。

import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
import matplotlib.gridspec as gridspec 
from matplotlib.widgets import Slider 

# Build 100 random frames of 30 rows each, keyed by frame number.
dfs = {}
for x in range(100):
    # col1: noisy signal around 10 (standard deviation 0.5).
    col1 = np.random.normal(10, 0.5, 30)
    # col2: three consecutive plateaus with the values 5, 8 and 7.  The
    # plateau lengths come from a random 3-way split of 31 samples; they are
    # cast to int because np.repeat() rejects float repeat counts, and the
    # result is trimmed to exactly 30 rows.
    weights = np.random.dirichlet(np.ones(3), size=1)[0]
    counts = np.round(weights * 31).astype(int)
    col2 = np.repeat([5, 8, 7], counts)[:30]
    # col3: random integers in 0..3.
    col3 = np.random.randint(4, size=30)
    dfs[x] = pd.DataFrame({'col1': col1, 'col2': col2, 'col3': col3})

# Figure with a single plotting axis laid out through a 1x1 GridSpec.
fig = plt.figure()
gs = gridspec.GridSpec(nrows=1, ncols=1, left=0.1, bottom=0.1, wspace=0, hspace=0)
ax = plt.subplot(gs[0])
ax.set_ylim(0, 12)

# Frame-selection slider, placed in its own small axis below the plot.
frame = 0
axframe = plt.axes([0.13, 0.02, 0.75, 0.03])
sframe = Slider(ax=axframe, label='frame', valmin=0, valmax=99, valinit=0, valfmt='%d')

# Everything below is drawn from the first frame.
first = dfs[0]

# Line artists; the handles are kept so the slider callback can update them.
(ln1,) = ax.plot(first.index, first['col1'])
(ln2,) = ax.plot(first.index, first['col2'], c='black')

# Shaded band of height 1 around col2, one colour per plateau value.
for level, colour in ((5, 'r'), (8, 'b'), (7, 'g')):
    ax.fill_between(
        first.index,
        y1=first['col2'] - 0.5,
        y2=first['col2'] + 0.5,
        where=first['col2'] == level,
        facecolor=colour,
        edgecolors='none',
        alpha=0.5,
    )

# Vertical stems from 0 up to col3 at every index position.
ax.vlines(x=first['col3'].index, ymin=0, ymax=first['col3'], color='black')

#update plots
def update(val):
    """Slider callback: show the data of the currently selected frame.

    Parameters
    ----------
    val : float
        Value handed over by the slider; the position is read from
        ``sframe.val`` instead, so this argument is not used.

    Only ``ln1``, ``ln2`` and the axis title are refreshed here.  The
    artists created with ``fill_between()`` and ``vlines()`` are not
    touched and keep showing the first frame.
    """
    # Convert once to a real int so the dict lookup and the title share the
    # same key instead of relying on a float comparing equal to an int key.
    frame = int(sframe.val)
    data = dfs[frame]
    ln1.set_ydata(data['col1'])
    ln2.set_ydata(data['col2'])
    ax.set_title('Frame ' + str(frame))
    # Ask the canvas of this figure for a redraw.
    fig.canvas.draw_idle()

# Register update() as the handler for slider changes, then display the
# figure.
sframe.on_changed(update) 
plt.show() 

目前的效果如下:enter image description here

我不能采用与 plot() 相同的方法,因为下面的代码会产生错误信息:

ln3,=ax.fill_between(dfs[0].index,y1=dfs[0]['col2']-0.5,y2=dfs[0]['col2']+0.5,where=dfs[0]['col2']==5,facecolor='r',edgecolors='none',alpha=0.5) 
TypeError: 'PolyCollection' object is not iterable 

每一帧应该呈现的效果如下:enter image description here

回答

1

fill_between 返回一个 PolyCollection,它在创建时需要一个(或多个)顶点列表。不幸的是,我还没有找到一种方法来取回创建给定 PolyCollection 时所用的顶点,但在你的情况下,可以很容易地直接创建 PolyCollection(从而避免使用 fill_between),然后在切换帧时更新顶点。

下面这个版本的代码可以实现你想要的效果:

import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
import matplotlib.gridspec as gridspec 
from matplotlib.widgets import Slider 

from matplotlib.collections import PolyCollection 

# Build 100 random frames of 30 rows each, keyed by frame number.
dfs = {}
for x in range(100):
    # col1: noisy signal around 10 (standard deviation 0.5).
    col1 = np.random.normal(10, 0.5, 30)
    # col2: three consecutive plateaus with the values 5, 8 and 7.  The
    # plateau lengths come from a random 3-way split of 31 samples; they are
    # cast to int because np.repeat() rejects float repeat counts, and the
    # result is trimmed to exactly 30 rows.
    weights = np.random.dirichlet(np.ones(3), size=1)[0]
    counts = np.round(weights * 31).astype(int)
    col2 = np.repeat([5, 8, 7], counts)[:30]
    # col3: random integers in 0..3.
    col3 = np.random.randint(4, size=30)
    dfs[x] = pd.DataFrame({'col1': col1, 'col2': col2, 'col3': col3})

# Figure with a single plotting axis laid out through a 1x1 GridSpec.
fig = plt.figure()
gs = gridspec.GridSpec(nrows=1, ncols=1, left=0.1, bottom=0.1, wspace=0, hspace=0)
ax = plt.subplot(gs[0])
ax.set_ylim(0, 12)

# Frame-selection slider, placed in its own small axis below the plot.
frame = 0
axframe = plt.axes([0.13, 0.02, 0.75, 0.03])
sframe = Slider(ax=axframe, label='frame', valmin=0, valmax=99, valinit=0, valfmt='%d')

# Line artists for the first frame; the handles are kept so the slider
# callback can update them.
first = dfs[0]
(ln1,) = ax.plot(first.index, first['col1'])
(ln2,) = ax.plot(first.index, first['col2'], c='black')

# col2 values highlighted by the red, blue and green PolyCollections.
val_r, val_b, val_g = 5, 8, 7

def update_collection(collection, value, frame = 0):
    """Reshape ``collection`` into the band that marks ``col2 == value``.

    Parameters
    ----------
    collection : PolyCollection
        Collection whose vertices are replaced via ``set_verts``.
    value : number
        The ``col2`` value to highlight.
    frame : int, optional
        Key into ``dfs`` selecting the frame to read (default 0).

    The band is a single rectangle from the smallest to the largest index
    at which ``col2 == value``, extending 0.5 below and above ``value``.
    That matches the data only if those rows are contiguous.  When no row
    matches, the collection receives one empty polygon.
    """
    data = dfs[frame]
    xs = np.asarray(data.index)
    mask = np.asarray(data['col2']) == value

    # Test for "no matching rows" explicitly rather than catching the
    # ValueError from np.min on an empty array: unrelated errors then
    # propagate unchanged and ``verts`` is always bound.
    if mask.any():
        hits = xs[mask]
        minx, maxx = hits.min(), hits.max()
        miny, maxy = value - 0.5, value + 0.5
        verts = np.array([[minx, miny], [maxx, miny], [maxx, maxy], [minx, maxy]])
    else:
        verts = np.zeros((0, 2))

    collection.set_verts([verts])

# Shaded bands: one initially empty PolyCollection per highlighted value,
# used in place of the fill_between() calls so the vertices can be replaced
# later.  Each is added to the axis and then shaped for frame 0.
reds, blues, greens = (
    PolyCollection([], facecolors=[colour], alpha=0.5) for colour in ('r', 'b', 'g')
)
for band, level in ((reds, val_r), (blues, val_b), (greens, val_g)):
    ax.add_collection(band)
    update_collection(band, level)

# Keep the LineCollection returned by vlines() so that its segments can be
# replaced whenever the frame changes.
stems = ax.vlines(x=dfs[0]['col3'].index, ymin=0, ymax=dfs[0]['col3'], color='black')


def _stem_segments(frame):
    """Return one vertical segment [(x, 0), (x, col3)] per row of dfs[frame]."""
    col3 = dfs[frame]['col3']
    return [[(x, 0), (x, top)] for x, top in zip(col3.index, col3)]


#update plots
def update(val):
    """Slider callback: redraw every artist for the selected frame.

    Parameters
    ----------
    val : float
        Value handed over by the slider; the position is read from
        ``sframe.val`` instead, so this argument is not used.

    Refreshes both lines, the title, the three PolyCollections and the
    vertical stems built from ``col3``.
    """
    # Convert once to a real int so every lookup and the title share the
    # same key instead of relying on a float comparing equal to an int key.
    frame = int(sframe.val)
    data = dfs[frame]
    ln1.set_ydata(data['col1'])
    ln2.set_ydata(data['col2'])
    ax.set_title('Frame ' + str(frame))

    # Shaded bands.
    update_collection(reds, val_r, frame)
    update_collection(blues, val_b, frame)
    update_collection(greens, val_g, frame)

    # Vertical stems.
    stems.set_segments(_stem_segments(frame))

    # Ask the canvas of this figure for a redraw.
    fig.canvas.draw_idle()

# Register update() as the handler for slider changes, then display the
# figure.
sframe.on_changed(update) 
plt.show() 

三个 PolyCollection(reds、blues 和 greens)各自只有四个顶点(矩形的四个角),这些顶点根据给定的数据确定(在 update_collection 中完成)。结果如下:

example result of given code

已在 Python 3.5 中测试。

相关问题