Problem description
I have a Pandas DataFrame like this:
```
  ticket        date   close
0    AAA  2018-01-12  176.16
1    AAA  2018-01-13  176.49
3    AAA  2018-01-14  176.00
4    BBB  2018-01-12   78.19
5    BBB  2018-01-13   79.90
6    BBB  2018-01-14   78.10
```
I have a function:
```python
def rsi(dataframe, period, column='close'):
    # Price change from the previous row
    delta = dataframe[column].diff()
    # Split the changes into gains (up) and losses (down)
    up, down = delta.copy(), delta.copy()
    up[up < 0] = 0
    down[down > 0] = 0
    # Exponentially smoothed average gain and loss (alpha = 1 / period)
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com=period - 1, adjust=False).mean().abs()
    rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    dataframe['rsi'] = rsi
    return dataframe
```
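For reference, the function computes RSI as below, where $c_t$ is the value of `column` (`close` by default) and $n$ is `period`. `ewm(com=n - 1, adjust=False)` applies the recursive smoothing with $\alpha = 1/(1 + \mathrm{com}) = 1/n$, starting from the first non-missing value. The code keeps losses as negative numbers and takes `.abs()` at the end, which gives the same $\bar{D}_t$, because a weighted average of non-positive numbers is itself non-positive.

```latex
\Delta_t = c_t - c_{t-1}, \qquad U_t = \max(\Delta_t, 0), \qquad D_t = \max(-\Delta_t, 0)

\alpha = \frac{1}{1 + \mathrm{com}} = \frac{1}{n}

\bar{U}_t = (1 - \alpha)\,\bar{U}_{t-1} + \alpha\,U_t, \qquad \bar{D}_t = (1 - \alpha)\,\bar{D}_{t-1} + \alpha\,D_t

\mathrm{RSI}_t = 100 - \frac{100}{1 + \bar{U}_t / \bar{D}_t}
```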
What I need is to apply this function to each group from `groupby('ticket')`. I tried the following, but it doesn't work. Any suggestions?

```python
print(dataframe.groupby('ticket').apply(rsi, 2))
```
I get an error:

```
cannot reindex from a duplicate axis
```
The whole source code is:
```python
# -*- coding: utf-8 -*-
import json
import pandas
import requests
import datetime


def get_historical_prices(tickets, range):
    request_params = {'symbols': ','.join(tickets), 'types': 'chart', 'range': range}
    json = requests.get('https://api.iextrading.com/1.0/stock/market/batch', params = request_params).json()
    united_dataframe = pandas.DataFrame()
    for symbol in json:
        ticket_dataframe = pandas.DataFrame(json[symbol]['chart'])
        ticket_dataframe.insert(0, 'ticket', symbol)
        united_dataframe = united_dataframe.append(ticket_dataframe)
    return united_dataframe[['ticket', 'date', 'close']]


def rsi(dataframe, period, column = 'close'):
    delta = all_prices[column].diff()
    up, down = delta.copy(), delta.copy()
    up[up < 0] = 0
    down[down > 0] = 0
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com= period -1, adjust=False).mean().abs()
    rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    dataframe['rsi'] = rsi
    return dataframe


# Get the data
tickets = ['AAPL', 'FB', 'TSLA']
all_prices = get_historical_prices(tickets, '1m')
print(all_prices.groupby('ticket').apply(rsi, 2))
```
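The error message points at the index. `get_historical_prices` builds one frame per symbol, each with its own 0-based index, so the labels repeat once those frames are stacked. In `rsi`, `delta` is taken from `all_prices`, so the `rsi` series carries those repeated labels. `dataframe['rsi'] = rsi` then has to align that series with the group's index, and pandas cannot reindex a series whose labels are duplicated. The toy sketch below reproduces the repeated labels using the question's sample prices. It uses `pd.concat` in place of the loop's `DataFrame.append`, which behaves the same way for the index here and was removed in pandas 2.0.

```python
import pandas as pd

# One frame per ticket, each with its own default index 0..n-1,
# mimicking how get_historical_prices builds one frame per symbol.
aaa = pd.DataFrame({'ticket': 'AAA', 'close': [176.16, 176.49, 176.00]})
bbb = pd.DataFrame({'ticket': 'BBB', 'close': [78.19, 79.90, 78.10]})

# Stacking keeps the original labels, so 0, 1 and 2 each appear twice.
combined = pd.concat([aaa, bbb])
print(combined.index.tolist())   # [0, 1, 2, 0, 1, 2]
print(combined.index.is_unique)  # False

# ignore_index=True renumbers the rows 0..5 instead.
renumbered = pd.concat([aaa, bbb], ignore_index=True)
print(renumbered.index.is_unique)  # True
```

Building the frame with unique labels (for example with `ignore_index=True`, or `reset_index(drop=True)` afterwards) would make the error go away. However, `rsi` would then still be reading `all_prices` rather than the group it was given.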
Answer 1
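There is a problem in the source code. The line

```python
delta = all_prices[column].diff()
```

should be

```python
delta = dataframe[column].diff()
```

With that fixed, the code runs without problems.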
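Why this one change matters: inside `groupby('ticket').apply(rsi, 2)`, the `dataframe` argument holds only the current ticket's rows, while `all_prices` holds every ticket. The minimal sketch below uses the question's sample prices to compare `diff()` over the whole frame with `diff()` per ticket. The per-ticket version is what the corrected `rsi` effectively computes for each group.

```python
import pandas as pd

df = pd.DataFrame({
    'ticket': ['AAA'] * 3 + ['BBB'] * 3,
    'close': [176.16, 176.49, 176.00, 78.19, 79.90, 78.10],
})

# Over the whole frame, BBB's first row is compared with AAA's last close.
global_delta = df['close'].diff()

# Per ticket, each group starts again with NaN.
grouped_delta = df.groupby('ticket')['close'].diff()

print(pd.DataFrame({'global': global_delta, 'grouped': grouped_delta}))
```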
Reassigning the result adds the column `rsi` to `all_prices`, i.e.

```python
all_prices = all_prices.groupby('ticket').apply(rsi, 2)
```
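Here is a sketch of an alternative that avoids `apply` altogether (it is not taken from either answer). It computes the RSI from each ticket's `close` series with `groupby(...).transform`, which returns values aligned with the original rows, so the result can be assigned directly as a new column. This also sidesteps how `groupby.apply` labels its output, which has changed across pandas versions. `rsi_series` is a hypothetical helper name. The prices are the question's sample rows, renumbered 0 to 5 so the index is unique.

```python
import pandas as pd


def rsi_series(close, period):
    """RSI of one ticket's prices, same arithmetic as the question's rsi.

    Hypothetical helper: takes a Series and returns a Series with the same index.
    """
    delta = close.diff()
    up = delta.clip(lower=0)        # gains; losses become 0 (first row stays NaN)
    down = (-delta).clip(lower=0)   # losses as positive numbers; gains become 0
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com=period - 1, adjust=False).mean()
    return 100 - 100 / (1 + rolling_up / rolling_down)


prices = pd.DataFrame({
    'ticket': ['AAA'] * 3 + ['BBB'] * 3,
    'date': ['2018-01-12', '2018-01-13', '2018-01-14'] * 2,
    'close': [176.16, 176.49, 176.00, 78.19, 79.90, 78.10],
})

# transform runs rsi_series on each ticket's close column and returns
# a result aligned with the original rows of prices.
prices['rsi'] = prices.groupby('ticket')['close'].transform(lambda c: rsi_series(c, 2))
print(prices)
```

With `period=2`, the first row of each ticket is NaN because `diff()` has no previous price. A ticket whose first move is a gain gets an RSI of 100 on that row, because its smoothed loss is still 0.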
So the final code and its result look like this:
```
In [20]: # -*- coding: utf-8 -*-
    ...:
    ...: import json
    ...: import pandas
    ...: import requests
    ...: import datetime
    ...:
    ...: def get_historical_prices(tickets, range):
    ...:     request_params = {'symbols': ','.join(tickets), 'types': 'chart', 'range': range}
    ...:     json = requests.get('https://api.iextrading.com/1.0/stock/market/batch', params = request_params).json()
    ...:     united_dataframe = pandas.DataFrame()
    ...:     for symbol in json:
    ...:         ticket_dataframe = pandas.DataFrame(json[symbol]['chart'])
    ...:         ticket_dataframe.insert(0, 'ticket', symbol)
    ...:         united_dataframe = united_dataframe.append(ticket_dataframe)
    ...:     return united_dataframe[['ticket', 'date', 'close']]
    ...:
    ...: def rsi(dataframe, period, column = 'close'):
    ...:     delta = dataframe[column].diff()
    ...:     up, down = delta.copy(), delta.copy()
    ...:     up[up < 0] = 0
    ...:     down[down > 0] = 0
    ...:     rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    ...:     rolling_down = down.ewm(com= period -1, adjust=False).mean().abs()
    ...:     rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    ...:     dataframe['rsi'] = rsi
    ...:     return dataframe
    ...:
    ...: # Get the data
    ...: tickets = ['AAPL', 'FB', 'TSLA']
    ...: all_prices = get_historical_prices(tickets, '1m')
    ...:
    ...: all_prices = all_prices.groupby('ticket').apply(rsi, 2)
    ...: print(all_prices.head())
    ...:
    ...:
  ticket        date   close        rsi
0   AAPL  2018-01-12  177.09        NaN
1   AAPL  2018-01-16  176.19   0.000000
2   AAPL  2018-01-17  179.10  76.377953
3   AAPL  2018-01-18  179.26  78.208232
4   AAPL  2018-01-19  178.46  44.065484
```
Answer 2
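The problem here is with the lines

```python
dataframe['rsi'] = rsi
return dataframe
```

The issue is that `rsi` does not have the same index as `dataframe`, and its length is different too. I changed those lines to

```python
return rsi
```

and the code runs without problems.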
So the final code and its result look like this:
```
In [12]: # -*- coding: utf-8 -*-
    ...:
    ...: import json
    ...: import pandas
    ...: import requests
    ...: import datetime
    ...:
    ...: def get_historical_prices(tickets, range):
    ...:     request_params = {'symbols': ','.join(tickets), 'types': 'chart', 'range': range}
    ...:     json = requests.get('https://api.iextrading.com/1.0/stock/market/batch', params = request_params).json()
    ...:     united_dataframe = pandas.DataFrame()
    ...:     for symbol in json:
    ...:         ticket_dataframe = pandas.DataFrame(json[symbol]['chart'])
    ...:         ticket_dataframe.insert(0, 'ticket', symbol)
    ...:         united_dataframe = united_dataframe.append(ticket_dataframe)
    ...:     return united_dataframe[['ticket', 'date', 'close']]
    ...:
    ...: def rsi(dataframe, period, column = 'close'):
    ...:     delta = all_prices[column].diff()
    ...:     up, down = delta.copy(), delta.copy()
    ...:     up[up < 0] = 0
    ...:     down[down > 0] = 0
    ...:     rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    ...:     rolling_down = down.ewm(com= period -1, adjust=False).mean().abs()
    ...:     rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    ...:
    ...:     return rsi
    ...:
    ...: # Get the data
    ...: tickets = ['AAPL', 'FB', 'TSLA']
    ...: all_prices = get_historical_prices(tickets, '1m')
    ...:
    ...: print(all_prices.groupby('ticket').apply(rsi, 2))
    ...:
    ...:
close     0    1          2          3          4          5          6  \
ticket
AAPL    NaN  0.0  76.377953  78.208232  44.065484  16.991057  19.694656
FB      NaN  0.0  76.377953  78.208232  44.065484  16.991057  19.694656
TSLA    NaN  0.0  76.377953  78.208232  44.065484  16.991057  19.694656

close          7         8        9  ...         11         12  \
ticket                               ...
AAPL    3.521704  1.252711  15.2917  ...  76.523444  48.103572
FB      3.521704  1.252711  15.2917  ...  76.523444  48.103572
TSLA    3.521704  1.252711  15.2917  ...  76.523444  48.103572

close          13         14         15       16         17         18  \
ticket
AAPL    80.777497  46.146025  23.884925  8.34273  16.897125  75.919398
FB      80.777497  46.146025  23.884925  8.34273  16.897125  75.919398
TSLA    80.777497  46.146025  23.884925  8.34273  16.897125  75.919398

close          19         20
ticket
AAPL    15.705838  12.501725
FB      15.705838  12.501725
TSLA    15.705838  12.501725

[3 rows x 61 columns]
```