
Correctly applying my own function to a grouped pandas DataFrame

Popularity: 50   Posted: 2023-07-16 10:46:53

I have a Pandas DataFrame like this:

   ticket date         close  
0    AAA  2018-01-12  176.16
1    AAA  2018-01-13  176.49
3    AAA  2018-01-14  176.00
4    BBB  2018-01-12  78.19
5    BBB  2018-01-13  79.90
6    BBB  2018-01-14  78.10

And I have a function:

def rsi(dataframe, period, column = 'close'):
    delta = dataframe[column].diff()
    up, down = delta.copy(), delta.copy()
    up[up < 0] = 0
    down[down > 0] = 0
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com= period -1, adjust=False).mean().abs()
    rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    dataframe['rsi'] =  rsi
    return dataframe
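For reference, the function is the usual RSI built from exponentially weighted averages of gains and losses. In pandas, `ewm(com=c, adjust=False)` uses the recursive form with smoothing factor α = 1/(1 + c), so `com=period - 1` gives α = 1/period (Wilder-style smoothing). Each average starts at the first non-NaN change, which is why the first RSI value is NaN, and taking `.abs()` of the average of the non-positive `down` values is the same as averaging the losses as positive numbers. Writing Δ_t for the change in the close column and U_t, D_t for the smoothed gains and losses:

```latex
\begin{aligned}
\alpha &= \frac{1}{1 + \mathrm{com}} = \frac{1}{1 + (\mathrm{period} - 1)} = \frac{1}{\mathrm{period}} \\
U_t &= (1 - \alpha)\,U_{t-1} + \alpha \max(\Delta_t, 0) \\
D_t &= (1 - \alpha)\,D_{t-1} + \alpha \max(-\Delta_t, 0) \\
\mathrm{RSI}_t &= 100 - \frac{100}{1 + U_t / D_t}
\end{aligned}
```

As a check with period = 2 (α = 1/2) and the AAPL closes 177.09, 176.19, 179.10 from the answer output further down: Δ = -0.90 gives U = 0, D = 0.90 and RSI = 0; then Δ = 2.91 gives U = 1.455, D = 0.45 and RSI = 100 - 100/(1 + 3.2333) ≈ 76.378, matching the printed 76.377953.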

What I need is to apply this function to each group produced by groupby('ticket'). I tried the following, but it doesn't work. Any suggestions?

print(dataframe.groupby('ticket').apply(rsi, 2))

I get an error:

cannot reindex from a duplicate axis
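A minimal sketch that reproduces this error, assuming two things the full source below also has: the combined frame carries a duplicate index (each ticker's frame keeps its own 0-based index when stacked), and the function reads the global frame instead of the group it is given. The frames `aaa` and `bbb` and the function `broken_rsi` are hypothetical, and the exact error wording varies between pandas versions.

```python
import pandas as pd

# Hypothetical per-ticker frames, each with its own 0-based index,
# stacked without ignore_index (as DataFrame.append in a loop does)
aaa = pd.DataFrame({"ticket": "AAA", "close": [176.16, 176.49, 176.00]})
bbb = pd.DataFrame({"ticket": "BBB", "close": [78.19, 79.90, 78.10]})
all_prices = pd.concat([aaa, bbb])  # index labels: 0, 1, 2, 0, 1, 2


def broken_rsi(dataframe, period, column="close"):
    # Hypothetical reduced version of the bug: reads the global
    # all_prices instead of the group passed in as `dataframe`
    delta = all_prices[column].diff()
    # Assigning a 6-row Series with duplicate labels to a 3-row group
    # forces a reindex, which pandas refuses on duplicate labels
    dataframe["rsi"] = delta
    return dataframe


error = None
try:
    all_prices.groupby("ticket").apply(broken_rsi, 2)
except ValueError as err:
    error = err
print(type(error).__name__, error)
```

Inside `broken_rsi` the group has three rows labelled 0, 1, 2, but `delta` has six rows with each label appearing twice, so pandas cannot line the two up.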

The full source code is:

# -*- coding: utf-8 -*-

import json
import pandas
import requests
import datetime

def get_historical_prices(tickets, range):
    request_params = {'symbols': ','.join(tickets), 'types': 'chart', 'range': range}
    json = requests.get('https://api.iextrading.com/1.0/stock/market/batch', params = request_params).json()
    united_dataframe = pandas.DataFrame()
    for symbol in json:
        ticket_dataframe = pandas.DataFrame(json[symbol]['chart'])
        ticket_dataframe.insert(0, 'ticket', symbol)
        united_dataframe = united_dataframe.append(ticket_dataframe)
    return united_dataframe[['ticket', 'date', 'close']]

def rsi(dataframe, period, column = 'close'):
    delta = all_prices[column].diff()
    up, down = delta.copy(), delta.copy()
    up[up < 0] = 0
    down[down > 0] = 0
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com= period -1, adjust=False).mean().abs()
    rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    dataframe['rsi'] =  rsi
    return dataframe

# Get the data
tickets = ['AAPL', 'FB', 'TSLA']
all_prices = get_historical_prices(tickets, '1m')

print(all_prices.groupby('ticket').apply(rsi, 2))

There is a bug in the source code. The line

delta = all_prices[column].diff()

should be

delta = dataframe[column].diff() 

With that fixed, the code runs without problems. Reassigning the result adds the rsi column to all_prices:

all_prices = all_prices.groupby('ticket').apply(rsi, 2)
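The reassignment above relies on how `groupby(...).apply` combines the DataFrames returned by the function, and the handling of group keys and of the grouping column in `apply` has changed across pandas releases. A version-independent sketch is to loop over the groups and concatenate the results yourself. It assumes the corrected `rsi` (reading from its `dataframe` argument) and uses the small table from the question as sample data; `prices` and `result` are illustrative names.

```python
import pandas as pd


def rsi(dataframe, period, column="close"):
    # Corrected version: every input comes from the group passed in
    delta = dataframe[column].diff()
    up, down = delta.copy(), delta.copy()
    up[up < 0] = 0
    down[down > 0] = 0
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com=period - 1, adjust=False).mean().abs()
    dataframe["rsi"] = 100 - 100 / (1 + rolling_up / rolling_down)
    return dataframe


# Sample data taken from the table in the question
prices = pd.DataFrame({
    "ticket": ["AAA"] * 3 + ["BBB"] * 3,
    "date": ["2018-01-12", "2018-01-13", "2018-01-14"] * 2,
    "close": [176.16, 176.49, 176.00, 78.19, 79.90, 78.10],
})

# Run rsi on a copy of each ticker's rows, then stack the pieces
# back together in the original row order
result = pd.concat(
    [rsi(group.copy(), 2) for _, group in prices.groupby("ticket")]
).sort_index()
print(result)
```

Copying each group keeps `rsi` from writing into a slice of the original frame, so `prices` itself is left unchanged.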

So the final code and result look like this:

In [20]: # -*- coding: utf-8 -*-
    ...: 
    ...: import json
    ...: import pandas
    ...: import requests
    ...: import datetime
    ...: 
    ...: def get_historical_prices(tickets, range):
    ...:     request_params = {'symbols': ','.join(tickets), 'types': 'chart', 'range': range}
    ...:     json = requests.get('https://api.iextrading.com/1.0/stock/market/batch', params = request_params).json()
    ...:     united_dataframe = pandas.DataFrame()
    ...:     for symbol in json:
    ...:         ticket_dataframe = pandas.DataFrame(json[symbol]['chart'])
    ...:         ticket_dataframe.insert(0, 'ticket', symbol)
    ...:         united_dataframe = united_dataframe.append(ticket_dataframe)
    ...:     return united_dataframe[['ticket', 'date', 'close']]
    ...: 
    ...: def rsi(dataframe, period, column = 'close'):
    ...:     delta = dataframe[column].diff()
    ...:     up, down = delta.copy(), delta.copy()
    ...:     up[up < 0] = 0
    ...:     down[down > 0] = 0
    ...:     rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    ...:     rolling_down = down.ewm(com= period -1, adjust=False).mean().abs()
    ...:     rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    ...:     dataframe['rsi'] = rsi
    ...:     return dataframe
    ...: 
    ...: # Get the data
    ...: tickets = ['AAPL', 'FB', 'TSLA']
    ...: all_prices = get_historical_prices(tickets, '1m')
    ...: 
    ...: all_prices = all_prices.groupby('ticket').apply(rsi, 2)
    ...: print(all_prices.head())
    ...: 
    ...: 
  ticket        date   close        rsi
0   AAPL  2018-01-12  177.09        NaN
1   AAPL  2018-01-16  176.19   0.000000
2   AAPL  2018-01-17  179.10  76.377953
3   AAPL  2018-01-18  179.26  78.208232
4   AAPL  2018-01-19  178.46  44.065484
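Two details of `get_historical_prices` are worth noting: `DataFrame.append` was deprecated in pandas 1.4 and removed in 2.0, and stacking per-ticker frames this way keeps each frame's own 0-based index, which is what produced the duplicate axis in the question. Below is a sketch of just the combining step with `pd.concat(..., ignore_index=True)`, assuming the response has the `{symbol: {"chart": [...]}}` shape the original loop reads. The helper `combine_charts` and the sample payload are hypothetical, and the network request is left out.

```python
import pandas as pd


def combine_charts(payload):
    """Hypothetical helper: turn {symbol: {"chart": [row, ...]}} into one frame."""
    frames = []
    for symbol, data in payload.items():
        frame = pd.DataFrame(data["chart"])
        frame.insert(0, "ticket", symbol)  # same column the original loop adds
        frames.append(frame)
    # ignore_index=True gives the combined frame a fresh 0..n-1 index,
    # so no two rows share a label
    combined = pd.concat(frames, ignore_index=True)
    return combined[["ticket", "date", "close"]]


# Made-up payload in the shape the original loop expects
payload = {
    "AAA": {"chart": [{"date": "2018-01-12", "close": 176.16},
                      {"date": "2018-01-13", "close": 176.49}]},
    "BBB": {"chart": [{"date": "2018-01-12", "close": 78.19},
                      {"date": "2018-01-13", "close": 79.90}]},
}
all_prices = combine_charts(payload)
print(all_prices)
```

Only the accumulation step changes; the column selection at the end is the same as in the original function.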

The problem here is with the lines

dataframe['rsi'] =  rsi
return dataframe

The problem is that rsi does not have the same index as the DataFrame, and it also has a different length.

I changed the lines above to

return rsi

and the code runs without errors.

So the final code and result look like this:

In [12]: # -*- coding: utf-8 -*-
    ...: 
    ...: import json
    ...: import pandas
    ...: import requests
    ...: import datetime
    ...: 
    ...: def get_historical_prices(tickets, range):
    ...:     request_params = {'symbols': ','.join(tickets), 'types': 'chart', 'range': range}
    ...:     json = requests.get('https://api.iextrading.com/1.0/stock/market/batch', params = request_params).json()
    ...:     united_dataframe = pandas.DataFrame()
    ...:     for symbol in json:
    ...:         ticket_dataframe = pandas.DataFrame(json[symbol]['chart'])
    ...:         ticket_dataframe.insert(0, 'ticket', symbol)
    ...:         united_dataframe = united_dataframe.append(ticket_dataframe)
    ...:     return united_dataframe[['ticket', 'date', 'close']]
    ...: 
    ...: def rsi(dataframe, period, column = 'close'):
    ...:     delta = all_prices[column].diff()
    ...:     up, down = delta.copy(), delta.copy()
    ...:     up[up < 0] = 0
    ...:     down[down > 0] = 0
    ...:     rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    ...:     rolling_down = down.ewm(com= period -1, adjust=False).mean().abs()
    ...:     rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    ...:     
    ...:     return rsi
    ...: 
    ...: # Get the data
    ...: tickets = ['AAPL', 'FB', 'TSLA']
    ...: all_prices = get_historical_prices(tickets, '1m')
    ...: 
    ...: print(all_prices.groupby('ticket').apply(rsi, 2))
    ...: 
    ...: 
close   0    1          2          3          4          5          6   \
ticket                                                                   
AAPL   NaN  0.0  76.377953  78.208232  44.065484  16.991057  19.694656   
FB     NaN  0.0  76.377953  78.208232  44.065484  16.991057  19.694656   
TSLA   NaN  0.0  76.377953  78.208232  44.065484  16.991057  19.694656   

close         7         8        9     ...             11         12  \
ticket                                 ...                             
AAPL    3.521704  1.252711  15.2917    ...      76.523444  48.103572   
FB      3.521704  1.252711  15.2917    ...      76.523444  48.103572   
TSLA    3.521704  1.252711  15.2917    ...      76.523444  48.103572   

close          13         14         15       16         17         18  \
ticket                                                                   
AAPL    80.777497  46.146025  23.884925  8.34273  16.897125  75.919398   
FB      80.777497  46.146025  23.884925  8.34273  16.897125  75.919398   
TSLA    80.777497  46.146025  23.884925  8.34273  16.897125  75.919398   

close          19         20  
ticket                        
AAPL    15.705838  12.501725  
FB      15.705838  12.501725  
TSLA    15.705838  12.501725  

[3 rows x 61 columns]
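Note that this version still reads `all_prices` inside `rsi`, so every group gets the RSI of the whole combined close column rather than of its own rows. Because each call returns the same Series, `apply` lays the results out as one row per ticker and one column per row of `all_prices`, which is why the AAPL, FB and TSLA rows are identical. A sketch that keeps the Series-returning idea but computes it per ticker, assuming the prices are in the `close` column: `rsi_series` is a hypothetical name, it uses `Series.clip` in place of the boolean-mask assignments (the arithmetic is the same), and `groupby(...).transform` returns a result aligned with the original rows, so it can be assigned straight back as a column.

```python
import pandas as pd


def rsi_series(close, period):
    """Hypothetical Series version of rsi: takes one ticker's close prices."""
    delta = close.diff()
    up = delta.clip(lower=0)       # gains; losses become 0
    down = (-delta).clip(lower=0)  # losses as positive numbers; gains become 0
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com=period - 1, adjust=False).mean()
    return 100 - 100 / (1 + rolling_up / rolling_down)


# Sample data taken from the table in the question
prices = pd.DataFrame({
    "ticket": ["AAA"] * 3 + ["BBB"] * 3,
    "date": ["2018-01-12", "2018-01-13", "2018-01-14"] * 2,
    "close": [176.16, 176.49, 176.00, 78.19, 79.90, 78.10],
})

# transform runs rsi_series on each ticker's close prices separately and
# returns a Series aligned with the rows of prices
prices["rsi"] = prices.groupby("ticket")["close"].transform(
    lambda close: rsi_series(close, 2)
)
print(prices)
```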