This section walks through using Connections, the MySqlOperator, and XComs to implement a complete Airflow ETL pipeline.
1. Storing data in the database the naive way
1.1 Create the tables
CREATE DATABASE demodb;
USE demodb;

CREATE TABLE stock_prices_stage
(
    ticker      varchar(30),
    as_of_date  date,
    open_price  double,
    high_price  double,
    low_price   double,
    close_price double
) COMMENT = 'stock price staging table';

CREATE TABLE stock_prices
(
    id          int NOT NULL AUTO_INCREMENT,
    ticker      varchar(30),
    as_of_date  date COMMENT 'as-of date',
    open_price  double,
    high_price  double,
    low_price   double,
    close_price double,
    created_at  timestamp DEFAULT now(),
    updated_at  timestamp DEFAULT now(),
    PRIMARY KEY (id)
) COMMENT = 'stock price table';

CREATE INDEX ids_stockprices ON stock_prices (ticker, as_of_date);
CREATE INDEX ids_stockpricestage ON stock_prices_stage (ticker, as_of_date);
2. Managing database connection info with Airflow Connections
Building on the code from the previous section, we now load the data that was saved to files into the database. The V2 code is as follows:
download_stock_price_v2.py
2.1 The traditional connection approach
"""Example DAG demonstrating the usage of the BashOperator."""from datetime import timedeltafrom textwrap import dedentimport yfinance as yffrom airflow import DAGfrom airflow.operators.python import PythonOperatorfrom airflow.utils.dates import days_agofrom airflow.models import Variableimport mysql.connectordef download_price(*args, **context): stock_list = get_tickers(context) for ticker in stock_list: dat = yf.Ticker(ticker) hist = dat.history(period="1mo") # print(type(hist)) # print(hist.shape) # print(os.getcwd()) with open(get_file_path(ticker), 'w') as writer: hist.to_csv(writer, index=True) print("Finished downloading price data for " + ticker)def get_file_path(ticker): # NOT SAVE in distributed system return f'./{ticker}.csv'def load_price_data(ticker): with open(get_file_path(ticker), 'r') as reader: lines = reader.readlines() return [[ticker] + line.split(',')[:5] for line in lines if line[:4] != 'Date']def get_tickers(context): # 获取配置的变量Variables stock_list = Variable.get("stock_list_json", deserialize_json=True) # 如果有配置参数,则应用配置参数的数据(Trigger DAG with parameters) stocks = context["dag_run"].conf.get("stocks") if stocks: stock_list = stocks return stock_listdef save_to_mysql_stage(*args, **context): tickers = get_tickers(context) # 连贯数据库 mydb = mysql.connector.connect( host="98.14.13.15", user="root", password="Quant888", database="demodb", port=3307 ) mycursor = mydb.cursor() for ticker in tickers: val = load_price_data(ticker) print(f"{ticker} length={len(val)} {val[1]}") sql = """INSERT INTO stock_prices_stage (ticker, as_of_date, open_price, high_price, low_price, close_price) VALUES (%s,%s,%s,%s,%s,%s)""" mycursor.executemany(sql, val) mydb.commit() print(mycursor.rowcount, "record inserted.")default_args = { 'owner': 'airflow'}# [START instantiate_dag]with DAG( dag_id='download_stock_price_v2', default_args=default_args, description='download stock price and save to local csv files and save to database', schedule_interval=None, start_date=days_ago(2), tags=['quantdata'],) as dag: # [END instantiate_dag] dag.doc_md = """ This DAG download stock price """ download_task = PythonOperator( task_id="download_prices", python_callable=download_price, provide_context=True ) save_to_mysql_task = PythonOperator( task_id="save_to_database", python_callable=save_to_mysql_stage, provide_context=True ) download_task >> save_to_mysql_task
Then trigger the DAG manually from the Airflow UI. The first two runs failed; after some debugging, it executed successfully.
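Alternatively, since get_tickers() reads dag_run.conf, you can trigger the DAG from the CLI and pass a custom ticker list; a quick sketch (the tickers shown are placeholders):

$ airflow dags trigger download_stock_price_v2 --conf '{"stocks": ["AAPL", "MSFT"]}'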
You can see the data has been loaded into the database:
2.2 Managing connection info with Airflow Connections
The demo above has a problem: the database connection is hard-coded in the source, which makes later maintenance painful. Airflow provides Connections for exactly this; we can store the connection information there instead.
When choosing the connection type, the MySQL connection type is missing:
Conn Type missing? Make sure you've installed the corresponding Airflow Provider Package.
See the official documentation:
https://airflow.apache.org/do...
https://airflow.apache.org/do...
https://airflow.apache.org/do...
$ pip install apache-airflow-providers-mysql
Refresh the connections page, and the MySQL connection type now appears:
Then fill in the database connection details:
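If you prefer the command line over the UI, the same Connection can be created with the Airflow CLI; a sketch, where the host, password, and port are placeholders for your own values:

$ airflow connections add 'demodb' \
    --conn-type 'mysql' \
    --conn-host 'your-mysql-host' \
    --conn-login 'root' \
    --conn-password 'your-password' \
    --conn-schema 'demodb' \
    --conn-port 3306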
Then modify the code:
def save_to_mysql_stage(*args, **context):
    tickers = get_tickers(context)

    """
    # Connect to the database (hard-coded credentials)
    mydb = mysql.connector.connect(
        host="98.14.14.145",
        user="root",
        password="Quant888",
        database="demodb",
        port=3307
    )
    """

    # Fetch connection info dynamically from Airflow Connections
    from airflow.hooks.base_hook import BaseHook
    conn = BaseHook.get_connection('demodb')
    mydb = mysql.connector.connect(
        host=conn.host,
        user=conn.login,
        password=conn.password,
        database=conn.schema,
        port=conn.port
    )

    mycursor = mydb.cursor()
    for ticker in tickers:
        val = load_price_data(ticker)
        print(f"{ticker} length={len(val)} {val[1]}")
        sql = """INSERT INTO stock_prices_stage
                 (ticker, as_of_date, open_price, high_price, low_price, close_price)
                 VALUES (%s,%s,%s,%s,%s,%s)"""
        mycursor.executemany(sql, val)
        mydb.commit()
        print(mycursor.rowcount, "record inserted.")
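As an aside, the MySQL provider also ships a MySqlHook that wraps this boilerplate; a minimal sketch under the assumption that the same 'demodb' Connection exists (the save_rows_with_hook name is made up for illustration):

from airflow.providers.mysql.hooks.mysql import MySqlHook

def save_rows_with_hook(rows):
    # The hook reads host/login/password/schema/port from the 'demodb' Connection
    hook = MySqlHook(mysql_conn_id='demodb')
    # insert_rows batches the INSERTs and commits them for us
    hook.insert_rows(
        table='stock_prices_stage',
        rows=rows,
        target_fields=['ticker', 'as_of_date', 'open_price',
                       'high_price', 'low_price', 'close_price'],
    )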
Full code:
"""Example DAG demonstrating the usage of the BashOperator."""from datetime import timedeltafrom textwrap import dedentimport yfinance as yffrom airflow import DAGfrom airflow.operators.python import PythonOperatorfrom airflow.utils.dates import days_agofrom airflow.models import Variableimport mysql.connectordef download_price(*args, **context): stock_list = get_tickers(context) for ticker in stock_list: dat = yf.Ticker(ticker) hist = dat.history(period="1mo") # print(type(hist)) # print(hist.shape) # print(os.getcwd()) with open(get_file_path(ticker), 'w') as writer: hist.to_csv(writer, index=True) print("Finished downloading price data for " + ticker)def get_file_path(ticker): # NOT SAVE in distributed system return f'./{ticker}.csv'def load_price_data(ticker): with open(get_file_path(ticker), 'r') as reader: lines = reader.readlines() return [[ticker] + line.split(',')[:5] for line in lines if line[:4] != 'Date']def get_tickers(context): # 获取配置的变量Variables stock_list = Variable.get("stock_list_json", deserialize_json=True) # 如果有配置参数,则应用配置参数的数据(Trigger DAG with parameters) stocks = context["dag_run"].conf.get("stocks") if stocks: stock_list = stocks return stock_listdef save_to_mysql_stage(*args, **context): tickers = get_tickers(context) """ # 连贯数据库(硬编码方式连贯) mydb = mysql.connector.connect( host="98.14.13.14", user="root", password="Quan888", database="demodb", port=3307 ) """ # 应用airflow 的 Connections 动静获取配置信息 from airflow.hooks.base_hook import BaseHook conn = BaseHook.get_connection('demodb') mydb = mysql.connector.connect( host=conn.host, user=conn.login, password=conn.password, database=conn.schema, port=conn.port ) mycursor = mydb.cursor() for ticker in tickers: val = load_price_data(ticker) print(f"{ticker} length={len(val)} {val[1]}") sql = """INSERT INTO stock_prices_stage (ticker, as_of_date, open_price, high_price, low_price, close_price) VALUES (%s,%s,%s,%s,%s,%s)""" mycursor.executemany(sql, val) mydb.commit() print(mycursor.rowcount, "record inserted.")default_args = { 'owner': 'airflow'}# [START instantiate_dag]with DAG( dag_id='download_stock_price_v2', default_args=default_args, description='download stock price and save to local csv files and save to database', schedule_interval=None, start_date=days_ago(2), tags=['quantdata'],) as dag: # [END instantiate_dag] dag.doc_md = """ This DAG download stock price """ download_task = PythonOperator( task_id="download_prices", python_callable=download_price, provide_context=True ) save_to_mysql_task = PythonOperator( task_id="save_to_database", python_callable=save_to_mysql_stage, provide_context=True ) download_task >> save_to_mysql_task
3. Executing database operations with MySqlOperator
Create a new SQL file under the dags/ directory to merge the staging (stage) table data into the main table.
merge_stock_price.sql
-- update the existing rows
UPDATE stock_prices p, stock_prices_stage s
SET p.open_price  = s.open_price,
    p.high_price  = s.high_price,
    p.low_price   = s.low_price,
    p.close_price = s.close_price,
    updated_at    = now()
WHERE p.ticker = s.ticker
  AND p.as_of_date = s.as_of_date;

-- insert new rows
INSERT INTO stock_prices
    (ticker, as_of_date, open_price, high_price, low_price, close_price)
SELECT ticker, as_of_date, open_price, high_price, low_price, close_price
FROM stock_prices_stage s
WHERE NOT EXISTS
    (SELECT 1 FROM stock_prices p
     WHERE p.ticker = s.ticker AND p.as_of_date = s.as_of_date);

-- truncate the stage table
TRUNCATE TABLE stock_prices_stage;
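For reference, if (ticker, as_of_date) were declared as a UNIQUE index rather than the plain index created above, the UPDATE-then-INSERT pair could collapse into a single MySQL upsert; a sketch under that assumption:

-- assumes a UNIQUE index on (ticker, as_of_date), which the DDL above does not create
INSERT INTO stock_prices
    (ticker, as_of_date, open_price, high_price, low_price, close_price)
SELECT ticker, as_of_date, open_price, high_price, low_price, close_price
FROM stock_prices_stage
ON DUPLICATE KEY UPDATE
    open_price  = VALUES(open_price),
    high_price  = VALUES(high_price),
    low_price   = VALUES(low_price),
    close_price = VALUES(close_price),
    updated_at  = now();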
Create a new MySQL task in the download_stock_price_v2.py file:
First, add the import:
from airflow.providers.mysql.operators.mysql import MySqlOperator
mysql_task = MySqlOperator(
    task_id="merge_stock_price",
    mysql_conn_id='demodb',
    sql="merge_stock_price.sql",
    dag=dag,
)

download_task >> save_to_mysql_task >> mysql_task
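Note that sql is a templated field, and Airflow looks for merge_stock_price.sql relative to the folder containing the DAG file. If you keep SQL files in a separate folder, you can point the template engine at it via the DAG's template_searchpath; a sketch (the /opt/airflow/dags/sql path is a placeholder):

with DAG(
    dag_id='download_stock_price_v2',
    default_args=default_args,
    schedule_interval=None,
    start_date=days_ago(2),
    # extra folder(s) in which Jinja searches for .sql templates (placeholder path)
    template_searchpath=['/opt/airflow/dags/sql'],
) as dag:
    ...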
Full code:
"""Example DAG demonstrating the usage of the BashOperator."""from datetime import timedeltafrom textwrap import dedentimport yfinance as yfimport mysql.connectorfrom airflow import DAGfrom airflow.operators.python import PythonOperator# from airflow.operators.mysql_operator import MySqlOperatorfrom airflow.providers.mysql.operators.mysql import MySqlOperatorfrom airflow.utils.dates import days_agofrom airflow.models import Variabledef download_price(*args, **context): stock_list = get_tickers(context) for ticker in stock_list: dat = yf.Ticker(ticker) hist = dat.history(period="1mo") # print(type(hist)) # print(hist.shape) # print(os.getcwd()) with open(get_file_path(ticker), 'w') as writer: hist.to_csv(writer, index=True) print("Finished downloading price data for " + ticker)def get_file_path(ticker): # NOT SAVE in distributed system return f'./{ticker}.csv'def load_price_data(ticker): with open(get_file_path(ticker), 'r') as reader: lines = reader.readlines() return [[ticker] + line.split(',')[:5] for line in lines if line[:4] != 'Date']def get_tickers(context): # 获取配置的变量Variables stock_list = Variable.get("stock_list_json", deserialize_json=True) # 如果有配置参数,则应用配置参数的数据(Trigger DAG with parameters) stocks = context["dag_run"].conf.get("stocks") if stocks: stock_list = stocks return stock_listdef save_to_mysql_stage(*args, **context): tickers = get_tickers(context) """ # 连贯数据库(硬编码方式连贯) mydb = mysql.connector.connect( host="98.14.14.15", user="root", password="Quan888", database="demodb", port=3307 ) """ # 应用airflow 的 Connections 动静获取配置信息 from airflow.hooks.base_hook import BaseHook conn = BaseHook.get_connection('demodb') mydb = mysql.connector.connect( host=conn.host, user=conn.login, password=conn.password, database=conn.schema, port=conn.port ) mycursor = mydb.cursor() for ticker in tickers: val = load_price_data(ticker) print(f"{ticker} length={len(val)} {val[1]}") sql = """INSERT INTO stock_prices_stage (ticker, as_of_date, open_price, high_price, low_price, close_price) VALUES (%s,%s,%s,%s,%s,%s)""" mycursor.executemany(sql, val) mydb.commit() print(mycursor.rowcount, "record inserted.")default_args = { 'owner': 'airflow'}# [START instantiate_dag]with DAG( dag_id='download_stock_price_v2', default_args=default_args, description='download stock price and save to local csv files and save to database', schedule_interval=None, start_date=days_ago(2), tags=['quantdata'],) as dag: # [END instantiate_dag] dag.doc_md = """ This DAG download stock price """ download_task = PythonOperator( task_id="download_prices", python_callable=download_price, provide_context=True ) save_to_mysql_task = PythonOperator( task_id="save_to_database", python_callable=save_to_mysql_stage, provide_context=True ) mysql_task = MySqlOperator( task_id="merge_stock_price", mysql_conn_id='demodb', sql="merge_stock_price.sql", dag=dag, ) download_task >> save_to_mysql_task >> mysql_task
Then run the DAG manually; you can see it executed successfully:
Checking the relevant tables confirms the data was merged successfully.
4. Passing data between tasks with XComs
4.1 The XComs concept
XComs (short for "cross-communications") are a mechanism that lets tasks talk to each other, since by default tasks are entirely isolated and may run on completely different machines.
An XCom is identified by a key (essentially its name), along with the task_id and dag_id it came from. XComs can hold any (serializable) value, but they are only designed for small amounts of data; do not use them to pass around large values like DataFrames.
In short, XComs let multiple tasks communicate with each other (pass data between them).
XComs are explicitly "pushed" and "pulled" to/from their storage using the xcom_push and xcom_pull methods on Task Instances. Many operators will auto-push their results into an XCom key called return_value if the do_xcom_push argument is set to True (as it is by default), and @task functions do this as well.
# Pulls the return_value XCOM from "pushing_task"
value = task_instance.xcom_pull(task_ids='pushing_task')
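To make the push side explicit as well, here is a minimal two-task sketch (the task IDs and the 'my_key' key are made up for illustration; assume it sits inside a with DAG(...) block):

from airflow.operators.python import PythonOperator

def push_fn(**context):
    # explicitly push a value under a custom key
    context['ti'].xcom_push(key='my_key', value=[1, 2, 3])

def pull_fn(**context):
    # pull it back by source task id and key
    value = context['ti'].xcom_pull(task_ids='push_task', key='my_key')
    print(value)

push_task = PythonOperator(task_id='push_task', python_callable=push_fn)
pull_task = PythonOperator(task_id='pull_task', python_callable=pull_fn)
push_task >> pull_task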
4.2 Hands-on application
Scenario: add a non-existent stock ticker, then validate each ticker so that only stocks that actually exist are passed downstream.
Modify the download code in download_stock_price_v2.py:
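The key change: only tickers that actually return price data are kept, and the function returns them so the next task can pull the list via XCom (excerpted from the full code below):

def download_price(*args, **context):
    stock_list = get_tickers(context)
    # collect valid tickers (skip delisted or non-existent ones)
    valid_tickers = []
    for ticker in stock_list:
        dat = yf.Ticker(ticker)
        hist = dat.history(period="1mo")
        if hist.shape[0] > 0:
            valid_tickers.append(ticker)
        else:
            continue
        with open(get_file_path(ticker), 'w') as writer:
            hist.to_csv(writer, index=True)
        print("Finished downloading price data for " + ticker)
    # the return value is auto-pushed to XCom as "return_value"
    return valid_tickers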
Then, when saving the stocks to the MySQL stage table, fetch the already-filtered tickers returned by the previous task.
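The corresponding change in save_to_mysql_stage (again excerpted from the full code below):

# Pull the return_value XCom from "download_prices"
tickers = context['ti'].xcom_pull(task_ids='download_prices')
print(f"received tickers:{tickers}")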
download_stock_price_v2.py
Full code:
"""Example DAG demonstrating the usage of the BashOperator."""from datetime import timedeltafrom textwrap import dedentimport yfinance as yfimport mysql.connectorfrom airflow import DAGfrom airflow.operators.python import PythonOperator# from airflow.operators.mysql_operator import MySqlOperatorfrom airflow.providers.mysql.operators.mysql import MySqlOperatorfrom airflow.utils.dates import days_agofrom airflow.models import Variabledef download_price(*args, **context): stock_list = get_tickers(context) # 新增失常的股票(没有退市的或不存在的) valid_tickers = [] for ticker in stock_list: dat = yf.Ticker(ticker) hist = dat.history(period="1mo") # print(type(hist)) # print(hist.shape) # print(os.getcwd()) if hist.shape[0] > 0: valid_tickers.append(ticker) else: continue with open(get_file_path(ticker), 'w') as writer: hist.to_csv(writer, index=True) print("Finished downloading price data for " + ticker) # 减少返回值(用于工作之间数据的传递) return valid_tickersdef get_file_path(ticker): # NOT SAVE in distributed system return f'./{ticker}.csv'def load_price_data(ticker): with open(get_file_path(ticker), 'r') as reader: lines = reader.readlines() return [[ticker] + line.split(',')[:5] for line in lines if line[:4] != 'Date']def get_tickers(context): # 获取配置的变量Variables stock_list = Variable.get("stock_list_json", deserialize_json=True) # 如果有配置参数,则应用配置参数的数据(Trigger DAG with parameters) stocks = context["dag_run"].conf.get("stocks") if stocks: stock_list = stocks return stock_listdef save_to_mysql_stage(*args, **context): # tickers = get_tickers(context) # Pull the return_value XCOM from "pulling_task" tickers = context['ti'].xcom_pull(task_ids='download_prices') print(f"received tickers:{tickers}") """ # 连贯数据库(硬编码方式连贯) mydb = mysql.connector.connect( host="98.14.14.15", user="root", password="Quant888", database="demodb", port=3307 ) """ # 应用airflow 的 Connections 动静获取配置信息 from airflow.hooks.base_hook import BaseHook conn = BaseHook.get_connection('demodb') mydb = mysql.connector.connect( host=conn.host, user=conn.login, password=conn.password, database=conn.schema, port=conn.port ) mycursor = mydb.cursor() for ticker in tickers: val = load_price_data(ticker) print(f"{ticker} length={len(val)} {val[1]}") sql = """INSERT INTO stock_prices_stage (ticker, as_of_date, open_price, high_price, low_price, close_price) VALUES (%s,%s,%s,%s,%s,%s)""" mycursor.executemany(sql, val) mydb.commit() print(mycursor.rowcount, "record inserted.")default_args = { 'owner': 'airflow'}# [START instantiate_dag]with DAG( dag_id='download_stock_price_v2', default_args=default_args, description='download stock price and save to local csv files and save to database', schedule_interval=None, start_date=days_ago(2), tags=['quantdata'],) as dag: # [END instantiate_dag] dag.doc_md = """ This DAG download stock price """ download_task = PythonOperator( task_id="download_prices", python_callable=download_price, provide_context=True ) save_to_mysql_task = PythonOperator( task_id="save_to_database", python_callable=save_to_mysql_stage, provide_context=True ) mysql_task = MySqlOperator( task_id="merge_stock_price", mysql_conn_id='demodb', sql="merge_stock_price.sql", dag=dag, ) download_task >> save_to_mysql_task >> mysql_task
Then add a non-existent ticker (FBXXOO) to the Variables to verify that the XCom data passing works:
Run the DAG manually; the log output shows that tickers = context['ti'].xcom_pull(task_ids='download_prices') received the data passed over from the previous task.
Related articles:
Airflow concepts documentation
Airflow XComs official documentation